@@ -353,8 +353,17 @@ def _reorder_cache(self, *args, **kwargs):
353353 def prepare_inputs_for_generation (self , * args , ** kwargs ):
354354 return self .model .prepare_inputs_for_generation (* args , ** kwargs )
355355
356+ def _supports_logits_to_keep (self ) -> bool :
357+ """
358+ Return True if the current model supports the keyword argument `logits_to_keep` in forward()
359+ to save memory. Checking it in this way allows to avoid using a new model attribute.
360+ """
361+ return "logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
362+
356363 def _supports_num_logits_to_keep (self ) -> bool :
357364 """
365+ Will be deprecated after we no longer support transformers < 4.49
366+
358367 Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
359368 to save memory. Checking it in this way allows to avoid using a new model attribute.
360369 """
@@ -470,8 +479,17 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
470479 def get_encoder (self , * args , ** kwargs ):
471480 return self .model .get_encoder (* args , ** kwargs )
472481
482+ def _supports_logits_to_keep (self ) -> bool :
483+ """
484+ Return True if the current model supports the keyword argument `logits_to_keep` in forward()
485+ to save memory. Checking it in this way allows to avoid using a new model attribute.
486+ """
487+ return "logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
488+
473489 def _supports_num_logits_to_keep (self ) -> bool :
474490 """
491+ Will be deprecated after we no longer support transformers < 4.49
492+
475493 Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
476494 to save memory. Checking it in this way allows to avoid using a new model attribute.
477495 """
0 commit comments