Skip to content

Commit e1f75c3

Browse files
Update model_patcher.py
1 parent 8654a53 commit e1f75c3

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

optimum/exporters/openvino/model_patcher.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4362,8 +4362,10 @@ def lm_forward(self, attention_mask, position_ids, past_key_values, inputs_embed
43624362
deepstack_visual_embeds=deepstack_visual_embeds,
43634363
)
43644364
hidden_states = outputs[0]
4365+
logits_to_keep = 1
43654366
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
4366-
logits = self.lm_head(hidden_states)
4367+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
4368+
logits = self.lm_head(hidden_states[:, slice_indices, :])
43674369
return (logits, outputs.past_key_values.to_legacy_cache())
43684370

43694371
model.__orig_forward = model.forward

0 commit comments

Comments
 (0)