File tree Expand file tree Collapse file tree 2 files changed +37
-2
lines changed
Expand file tree Collapse file tree 2 files changed +37
-2
lines changed Original file line number Diff line number Diff line change @@ -497,6 +497,8 @@ def get_check_model_inputs_decorator():
497497
498498 In transformers >= 4.57.3, check_model_inputs became a function that returns a decorator.
499499 In older versions, it was directly a decorator.
500+ In transformers >= 5.2.0, check_model_inputs was removed and split into
501+ ``merge_with_config_defaults`` and ``capture_outputs``.
500502
501503 Returns:
502504 Decorator function to validate model inputs.
@@ -511,5 +513,19 @@ def get_check_model_inputs_decorator():
511513 # Old API: check_model_inputs is directly a decorator
512514 return check_model_inputs
513515 except ImportError :
514- # If transformers is not available, return a no-op decorator
515- return null_decorator
516+ pass
517+
518+ # transformers >= 5.2.0: check_model_inputs was split into two decorators
519+ try :
520+ from transformers .utils .generic import merge_with_config_defaults
521+ from transformers .utils .output_capturing import capture_outputs
522+
523+ def _combined_decorator (func ):
524+ return merge_with_config_defaults (capture_outputs (func ))
525+
526+ return _combined_decorator
527+ except ImportError :
528+ pass
529+
530+ # No transformers decorator available — return a no-op decorator
531+ return null_decorator
Original file line number Diff line number Diff line change @@ -202,6 +202,25 @@ def test_get_check_model_inputs_decorator():
202202 assert callable (decorator )
203203
204204
205+ def test_get_check_model_inputs_decorator_fallback_with_kwargs ():
206+ """
207+ The ``null_decorator`` fallback (returned by ``get_check_model_inputs_decorator``
208+ when transformers decorators are unavailable) must work as a plain ``@decorator``
209+ on functions called with keyword arguments.
210+
211+ This guards against a bug where a ``@contextmanager``-wrapped fallback
212+ produced a ``ContextDecorator`` whose ``__call__`` collided with model
213+ forward kwargs like ``input_ids``.
214+ """
215+ decorator = si .null_decorator
216+
217+ @decorator
218+ def dummy (x = None ):
219+ return x
220+
221+ assert dummy (x = 42 ) == 42
222+
223+
205224def test_null_decorator_as_direct_decorator ():
206225 """
207226 ``null_decorator`` must be a valid no-op decorator in ``@decorator`` form.
You can’t perform that action at this time.
0 commit comments