From 25cec5d96a38df62054de955631d059b1c50a1a6 Mon Sep 17 00:00:00 2001
From: Martin Yuan
Date: Tue, 25 Feb 2025 10:17:56 -0800
Subject: [PATCH] Fix pyre error for logits (#8687)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/8687

Fix the pyre error when logits is not initialized

Reviewed By: swolchok, jackzhxng

Differential Revision: D70183946
---
 examples/models/llama/llama_transformer.py | 44 +++++++++++-----------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index aba55705d20..3536936e47e 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -232,27 +232,29 @@ def forward(
         if self.apply_output:
             logits = self.output(h)
 
-        if self.output_prune_map is not None:
-            # expand to original size so that downstream applications can use the logits as-is.
-            if self.generate_full_logits:
-                # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], logits.shape[1], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, :, list(self.output_prune_map.values())] = logits
-            else:
-                # (1, pruned_size) -> (1, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, list(self.output_prune_map.values())] = logits
-            logits = expanded_logits
+            if self.output_prune_map is not None:
+                # expand to original size so that downstream applications can use the logits as-is.
+                if self.generate_full_logits:
+                    # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
+                    expanded_logits = torch.full(
+                        [logits.shape[0], logits.shape[1], self.vocab_size],
+                        float("-inf"),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    expanded_logits[:, :, list(self.output_prune_map.values())] = logits
+                else:
+                    # (1, pruned_size) -> (1, original_size)
+                    expanded_logits = torch.full(
+                        [logits.shape[0], self.vocab_size],
+                        float("-inf"),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    expanded_logits[:, list(self.output_prune_map.values())] = logits
+                logits = expanded_logits
+        else:
+            logits = h
 
         if attn_options_update is not None:
             return logits, attn_options_update
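
Note on the fix (illustrative, not part of the patch): pyre flagged `logits` as
possibly uninitialized because it was only assigned inside the
`if self.apply_output:` branch while the prune-map expansion below read it on
every path. The sketch below restates the patched control flow outside the
Transformer class so the fix is easier to see in isolation. It is a minimal
sketch under assumed names: `compute_logits` and its parameters are
hypothetical stand-ins, and only the non-full-logits expansion branch from the
patch is shown.

    import torch

    def compute_logits(h, output_proj, apply_output, output_prune_map, vocab_size):
        # Mirrors the patched control flow: `logits` is bound on every path,
        # which is what removes pyre's "logits may be uninitialized" error.
        if apply_output:
            logits = output_proj(h)
            if output_prune_map is not None:
                # (1, pruned_size) -> (1, vocab_size): scatter the pruned logits
                # back into a full-size tensor filled with -inf, as in the patch.
                expanded = torch.full(
                    [logits.shape[0], vocab_size],
                    float("-inf"),
                    device=logits.device,
                    dtype=logits.dtype,
                )
                expanded[:, list(output_prune_map.values())] = logits
                logits = expanded
        else:
            # No output projection: the hidden state passes through unchanged,
            # matching the new `else: logits = h` branch in the patch.
            logits = h
        return logits

The key design point is that the new `else` branch, rather than a sentinel such
as `logits = None`, gives every execution path a tensor-valued binding, so both
the type checker and downstream callers see a consistent return type.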