diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index aba55705d20..3536936e47e 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -232,27 +232,29 @@ def forward(
         if self.apply_output:
             logits = self.output(h)
 
-        if self.output_prune_map is not None:
-            # expand to original size so that downstream applications can use the logits as-is.
-            if self.generate_full_logits:
-                # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], logits.shape[1], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, :, list(self.output_prune_map.values())] = logits
-            else:
-                # (1, pruned_size) -> (1, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, list(self.output_prune_map.values())] = logits
-            logits = expanded_logits
+            if self.output_prune_map is not None:
+                # expand to original size so that downstream applications can use the logits as-is.
+                if self.generate_full_logits:
+                    # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
+                    expanded_logits = torch.full(
+                        [logits.shape[0], logits.shape[1], self.vocab_size],
+                        float("-inf"),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    expanded_logits[:, :, list(self.output_prune_map.values())] = logits
+                else:
+                    # (1, pruned_size) -> (1, original_size)
+                    expanded_logits = torch.full(
+                        [logits.shape[0], self.vocab_size],
+                        float("-inf"),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    expanded_logits[:, list(self.output_prune_map.values())] = logits
+                logits = expanded_logits
+        else:
+            logits = h
 
         if attn_options_update is not None:
             return logits, attn_options_update
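
For context on the (otherwise unchanged) pruning branch, below is a minimal standalone sketch of what the `expanded_logits` scatter does: pruned logits are written back into a full-vocab-size tensor at the original token indices, with all other positions filled with `-inf`. The map contents, vocab size, and tensor values here are made up for illustration and are not taken from the PR.

```python
import torch

# Hypothetical example values, not from the actual model or prune map.
vocab_size = 8
output_prune_map = {0: 2, 1: 5, 2: 7}  # i-th pruned logit -> original vocab index
logits = torch.tensor([[0.1, 0.9, 0.3]])  # shape (1, pruned_size)

# Same expansion pattern as the non-full-logits branch in the diff:
# start from -inf everywhere, then scatter the pruned logits back.
expanded = torch.full(
    [logits.shape[0], vocab_size],
    float("-inf"),
    device=logits.device,
    dtype=logits.dtype,
)
expanded[:, list(output_prune_map.values())] = logits
# expanded[0] == [-inf, -inf, 0.1, -inf, -inf, 0.9, -inf, 0.3]
```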