Skip to content

Commit 960cec6

Browse files
Return combined decorator in transformers 5.2.0
1 parent 44731d9 commit 960cec6

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

nemo_automodel/shared/import_utils.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,8 @@ def get_check_model_inputs_decorator():
497497
498498
In transformers >= 4.57.3, check_model_inputs became a function that returns a decorator.
499499
In older versions, it was directly a decorator.
500+
In transformers >= 5.2.0, check_model_inputs was removed and split into
501+
``merge_with_config_defaults`` and ``capture_outputs``.
500502
501503
Returns:
502504
Decorator function to validate model inputs.
@@ -511,5 +513,19 @@ def get_check_model_inputs_decorator():
511513
# Old API: check_model_inputs is directly a decorator
512514
return check_model_inputs
513515
except ImportError:
514-
# If transformers is not available, return a no-op decorator
515-
return null_decorator
516+
pass
517+
518+
# transformers >= 5.2.0: check_model_inputs was split into two decorators
519+
try:
520+
from transformers.utils.generic import merge_with_config_defaults
521+
from transformers.utils.output_capturing import capture_outputs
522+
523+
def _combined_decorator(func):
524+
return merge_with_config_defaults(capture_outputs(func))
525+
526+
return _combined_decorator
527+
except ImportError:
528+
pass
529+
530+
# No transformers decorator available — return a no-op decorator
531+
return null_decorator

tests/unit_tests/shared/test_import_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,25 @@ def test_get_check_model_inputs_decorator():
202202
assert callable(decorator)
203203

204204

205+
def test_get_check_model_inputs_decorator_fallback_with_kwargs():
206+
"""
207+
The ``null_decorator`` fallback (returned by ``get_check_model_inputs_decorator``
208+
when transformers decorators are unavailable) must work as a plain ``@decorator``
209+
on functions called with keyword arguments.
210+
211+
This guards against a bug where a ``@contextmanager``-wrapped fallback
212+
produced a ``ContextDecorator`` whose ``__call__`` collided with model
213+
forward kwargs like ``input_ids``.
214+
"""
215+
decorator = si.null_decorator
216+
217+
@decorator
218+
def dummy(x=None):
219+
return x
220+
221+
assert dummy(x=42) == 42
222+
223+
205224
def test_null_decorator_as_direct_decorator():
206225
"""
207226
``null_decorator`` must be a valid no-op decorator in ``@decorator`` form.

0 commit comments

Comments
 (0)