Commit 1d42deb

fix: resolving errors in the hf decorator function (#983)
Signed-off-by: Hemil Desai <[email protected]>
Co-authored-by: Hemil Desai <[email protected]>
Parent: b32dc1f

File tree: 5 files changed (+85 −6 lines)


nemo_automodel/components/models/biencoder/llama_bidirectional_model.py

Lines changed: 4 additions & 2 deletions
@@ -44,14 +44,16 @@
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, auto_docstring, logging
-from transformers.utils.generic import check_model_inputs
 
 try:
     from nemo_automodel.components.models.biencoder.state_dict_adapter import BiencoderStateDictAdapter
 except ImportError:
     BiencoderStateDictAdapter = object
 
+from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator
+
 logger = logging.get_logger(__name__)
+check_model_inputs = get_check_model_inputs_decorator()
 
 
 def contrastive_scores_and_labels(

@@ -177,7 +179,7 @@ def _update_causal_mask(
             return attention_mask
         return None
 
-    @check_model_inputs()
+    @check_model_inputs
    @auto_docstring
    def forward(
        self,
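
The same pattern is applied in the Llama and Qwen2 model files below: the decorator is resolved once at module import time, so it is applied bare (@check_model_inputs) instead of being called at the decoration site (@check_model_inputs()). A minimal sketch of what the resolved object looks like at runtime, assuming only the helpers added in this commit; the decorated forward shown in the comments is illustrative, not the repository's code:

# Sketch: inspect what get_check_model_inputs_decorator() resolves to.
from nemo_automodel.shared.import_utils import (
    get_check_model_inputs_decorator,
    is_transformers_min_version,
)

# Resolved once at import time; on transformers >= 4.57.3 this is the object
# returned by calling check_model_inputs(), on older versions it is
# check_model_inputs itself, and without transformers it is a no-op decorator.
check_model_inputs = get_check_model_inputs_decorator()

# In every case it is already a decorator, so model code applies it bare:
#
#     @check_model_inputs          # correct after this commit
#     def forward(self, ...): ...
#
# rather than the old call-style @check_model_inputs().
print(is_transformers_min_version("4.57.3"), callable(check_model_inputs))
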

nemo_automodel/components/models/llama/model.py

Lines changed: 4 additions & 2 deletions
@@ -50,18 +50,20 @@
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, can_return_tuple
-from transformers.utils.generic import check_model_inputs
 
 from nemo_automodel.components.models.common.combined_projection import (
     CombinedGateUpMLP,
     CombinedQKVAttentionMixin,
 )
 from nemo_automodel.components.models.llama.state_dict_adapter import LlamaStateDictAdapter
 from nemo_automodel.components.moe.utils import BackendConfig
+from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator
 from nemo_automodel.shared.utils import dtype_from_str
 
 __all__ = ["build_llama_model", "LlamaForCausalLM"]
 
+check_model_inputs = get_check_model_inputs_decorator()
+
 
 class LlamaAttention(CombinedQKVAttentionMixin, nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper with combined QKV projection."""

@@ -279,7 +281,7 @@ def __init__(
         # Initialize weights and apply final processing
         self.post_init()
 
-    @check_model_inputs()
+    @check_model_inputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

nemo_automodel/components/models/qwen2/model.py

Lines changed: 4 additions & 2 deletions
@@ -50,18 +50,20 @@
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, can_return_tuple
-from transformers.utils.generic import check_model_inputs
 
 from nemo_automodel.components.models.common.combined_projection import (
     CombinedGateUpMLP,
     CombinedQKVAttentionMixin,
 )
 from nemo_automodel.components.models.qwen2.state_dict_adapter import Qwen2StateDictAdapter
 from nemo_automodel.components.moe.utils import BackendConfig
+from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator
 from nemo_automodel.shared.utils import dtype_from_str
 
 __all__ = ["build_qwen2_model", "Qwen2ForCausalLM"]
 
+check_model_inputs = get_check_model_inputs_decorator()
+
 
 class Qwen2Attention(CombinedQKVAttentionMixin, nn.Module):
     """Multi-headed attention with combined QKV projection.

@@ -252,7 +254,7 @@ def __init__(
         # Initialize weights and apply final processing
         self.post_init()
 
-    @check_model_inputs()
+    @check_model_inputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

nemo_automodel/shared/import_utils.py

Lines changed: 47 additions & 0 deletions
@@ -419,3 +419,50 @@ def is_te_min_version(version, check_equality=True):
     if check_equality:
         return get_te_version() >= PkgVersion(version)
     return get_te_version() > PkgVersion(version)
+
+
+def get_transformers_version():
+    """Get transformers version from __version__."""
+    try:
+        import transformers
+
+        if hasattr(transformers, "__version__"):
+            _version = str(transformers.__version__)
+        else:
+            from importlib.metadata import version
+
+            _version = version("transformers")
+    except ImportError:
+        _version = "0.0.0"
+    return PkgVersion(_version)
+
+
+def is_transformers_min_version(version, check_equality=True):
+    """Check if minimum version of `transformers` is installed."""
+    if check_equality:
+        return get_transformers_version() >= PkgVersion(version)
+    return get_transformers_version() > PkgVersion(version)
+
+
+def get_check_model_inputs_decorator():
+    """
+    Get the appropriate check_model_inputs decorator based on transformers version.
+
+    In transformers >= 4.57.3, check_model_inputs became a function that returns a decorator.
+    In older versions, it was directly a decorator.
+
+    Returns:
+        Decorator function to validate model inputs.
+    """
+    try:
+        from transformers.utils.generic import check_model_inputs
+
+        if is_transformers_min_version("4.57.3"):
+            # New API: check_model_inputs() returns a decorator
+            return check_model_inputs()
+        else:
+            # Old API: check_model_inputs is directly a decorator
+            return check_model_inputs
+    except ImportError:
+        # If transformers is not available, return a no-op decorator
+        return null_decorator
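
The ImportError branch falls back to null_decorator, which already exists in nemo_automodel/shared/import_utils.py but is not shown in this diff. A minimal sketch of a no-op decorator of that shape, for illustration only (the repository's actual definition may differ):

def null_decorator(fn):
    """Return the wrapped callable unchanged; used when transformers is unavailable."""
    return fn


@null_decorator
def forward(x):
    return x


assert forward(3) == 3  # identical behavior to the undecorated function
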

tests/unit_tests/shared/test_import_utils.py

Lines changed: 26 additions & 0 deletions
@@ -174,3 +174,29 @@ def test_is_te_min_version():
     """
     assert si.is_te_min_version("0.0.0") is True
     assert si.is_te_min_version("9999.0.0", check_equality=False) is False
+
+
+def test_get_transformers_version_type():
+    """
+    ``get_transformers_version`` should *never* raise – even when transformers is unavailable
+    while building docs – and must always return a ``packaging.version.Version``.
+    """
+    ver = si.get_transformers_version()
+    assert isinstance(ver, PkgVersion)
+
+
+def test_is_transformers_min_version():
+    """
+    * A ridiculously low requirement must be satisfied.
+    * A far-future version must *not* be satisfied.
+    """
+    assert si.is_transformers_min_version("0.0.0") is True
+    assert si.is_transformers_min_version("9999.0.0", check_equality=False) is False
+
+
+def test_get_check_model_inputs_decorator():
+    """
+    ``get_check_model_inputs_decorator`` should always return a callable decorator.
+    """
+    decorator = si.get_check_model_inputs_decorator()
+    assert callable(decorator)
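
These tests pin down return types and callability but do not exercise the 4.57.3 branch point directly. A hedged sketch of how that boundary could be checked with pytest's monkeypatch, not part of this commit; `si` and `PkgVersion` are the same aliases used in the existing test module:

def test_is_transformers_min_version_boundary(monkeypatch):
    """Force specific versions to cover both sides of the 4.57.3 cutoff."""
    import nemo_automodel.shared.import_utils as si
    from packaging.version import Version as PkgVersion

    # Just below the cutoff: minimum-version check must fail.
    monkeypatch.setattr(si, "get_transformers_version", lambda: PkgVersion("4.57.2"))
    assert si.is_transformers_min_version("4.57.3") is False

    # Exactly at the cutoff: check_equality defaults to True, so it must pass.
    monkeypatch.setattr(si, "get_transformers_version", lambda: PkgVersion("4.57.3"))
    assert si.is_transformers_min_version("4.57.3") is True
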
