@@ -232,27 +232,29 @@ def forward(
         if self.apply_output:
             logits = self.output(h)
 
-        if self.output_prune_map is not None:
-            # expand to original size so that downstream applications can use the logits as-is.
-            if self.generate_full_logits:
-                # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], logits.shape[1], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, :, list(self.output_prune_map.values())] = logits
-            else:
-                # (1, pruned_size) -> (1, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, list(self.output_prune_map.values())] = logits
-            logits = expanded_logits
+            if self.output_prune_map is not None:
+                # expand to original size so that downstream applications can use the logits as-is.
+                if self.generate_full_logits:
+                    # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
+                    expanded_logits = torch.full(
+                        [logits.shape[0], logits.shape[1], self.vocab_size],
+                        float("-inf"),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    expanded_logits[:, :, list(self.output_prune_map.values())] = logits
+                else:
+                    # (1, pruned_size) -> (1, original_size)
+                    expanded_logits = torch.full(
+                        [logits.shape[0], self.vocab_size],
+                        float("-inf"),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    expanded_logits[:, list(self.output_prune_map.values())] = logits
+                logits = expanded_logits
+        else:
+            logits = h
 
         if attn_options_update is not None:
             return logits, attn_options_update
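
For context on what the moved block does, here is a minimal standalone sketch of the non-generate_full_logits branch. The vocabulary size and output_prune_map values are made-up illustrations, not taken from the model above; the point is only to show how pruned logits are scattered back into a full-size tensor filled with -inf so downstream code can consume them unchanged.

# Hypothetical sketch: expanding pruned logits back to the original vocab size.
import torch

vocab_size = 8                          # assumed original vocabulary size
output_prune_map = {0: 1, 1: 4, 2: 6}   # assumed map: pruned index -> original token id

logits = torch.tensor([[0.2, 1.5, -0.3]])   # (1, pruned_size)
expanded = torch.full(
    [logits.shape[0], vocab_size],
    float("-inf"),
    device=logits.device,
    dtype=logits.dtype,
)
expanded[:, list(output_prune_map.values())] = logits
# expanded is (1, 8): the pruned scores land at token ids 1, 4 and 6, while every
# other position stays at -inf, so argmax/softmax behave as if no pruning happened.

The change in the diff itself only nests this expansion under if self.apply_output: and adds an else branch that returns the hidden states h when no output projection is applied.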