Skip to content

Commit 51d9c08

Browse files
Support transformers v4.49 logits_to_keep for IPEX (#1188)
* fix logits_to_keep Signed-off-by: jiqing-feng <[email protected]> * fix typo Signed-off-by: jiqing-feng <[email protected]> * Update optimum/intel/ipex/modeling_base.py Co-authored-by: Ella Charlaix <[email protected]> * fix Signed-off-by: jiqing-feng <[email protected]> * add comments Signed-off-by: jiqing-feng <[email protected]> * fix format Signed-off-by: jiqing-feng <[email protected]> --------- Signed-off-by: jiqing-feng <[email protected]> Co-authored-by: Ella Charlaix <[email protected]>
1 parent 93ee486 commit 51d9c08

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

optimum/intel/ipex/modeling_base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,17 @@ def _reorder_cache(self, *args, **kwargs):
353353
def prepare_inputs_for_generation(self, *args, **kwargs):
354354
return self.model.prepare_inputs_for_generation(*args, **kwargs)
355355

356+
def _supports_logits_to_keep(self) -> bool:
357+
"""
358+
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
359+
to save memory. Checking it in this way allows to avoid using a new model attribute.
360+
"""
361+
return "logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
362+
356363
def _supports_num_logits_to_keep(self) -> bool:
357364
"""
365+
Will be deprecated after we no longer support transformers < 4.49
366+
358367
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
359368
to save memory. Checking it in this way allows to avoid using a new model attribute.
360369
"""
@@ -470,8 +479,17 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
470479
def get_encoder(self, *args, **kwargs):
471480
return self.model.get_encoder(*args, **kwargs)
472481

482+
def _supports_logits_to_keep(self) -> bool:
483+
"""
484+
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
485+
to save memory. Checking it in this way allows to avoid using a new model attribute.
486+
"""
487+
return "logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
488+
473489
def _supports_num_logits_to_keep(self) -> bool:
474490
"""
491+
Will be deprecated after we no longer support transformers < 4.49
492+
475493
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
476494
to save memory. Checking it in this way allows to avoid using a new model attribute.
477495
"""

0 commit comments

Comments
 (0)