Skip to content

Commit 1f93db4

Browse files
committed
fix
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 7726541 commit 1f93db4

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

nemo_automodel/_transformers/auto_model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,16 @@
4848

4949

5050
@contextmanager
51-
def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
51+
def local_torch_dtype(
52+
dtype: torch.dtype, model_class_name: str | None = None, default_dtype: torch.dtype = torch.bfloat16
53+
):
5254
"""
5355
Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
5456
If `model_class_name` is provided, it's used to provide a more helpful error message if `dtype` is not valid.
5557
"""
5658
# Raise a more helpful error here, before `torch.set_default_dtype` is called later on, which would crash in this case
59+
if isinstance(dtype, str):
60+
dtype = default_dtype
5761
if not dtype.is_floating_point:
5862
if model_class_name is not None:
5963
error_message = (
@@ -62,7 +66,6 @@ def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
6266
else:
6367
error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
6468
raise ValueError(error_message)
65-
6669
original_dtype = torch.get_default_dtype()
6770
try:
6871
torch.set_default_dtype(dtype)
@@ -426,9 +429,10 @@ def _retry(**override):
426429
_download_model_weights(hf_config, pretrained_model_name_or_path)
427430
logger.info(f"Using custom model implementation for {architectures[0]}")
428431
kwargs.pop("trust_remote_code", None)
429-
# TODO: restore weights after initialization.
430-
with local_torch_dtype(torch_dtype, ModelRegistry.model_arch_name_to_cls[architectures[0]].__name__):
431-
return ModelRegistry.model_arch_name_to_cls[architectures[0]](hf_config)
432+
# TODO(@akoumpa): restore weights after initialization.
433+
model_cls = ModelRegistry.model_arch_name_to_cls[architectures[0]]
434+
with local_torch_dtype(torch_dtype, model_cls.__name__):
435+
return model_cls(hf_config)
432436

433437
# 3. fallback to parent class
434438
model = None

tests/unit_tests/_transformers/test_auto_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def test_from_pretrained_uses_registry_when_available(self):
103103
# Prepare a fake custom model class and return value
104104
custom_model_instance = Mock()
105105
custom_cls = Mock(return_value=custom_model_instance)
106+
custom_cls.__name__ = "MockMockMock"
106107
mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}
107108

108109
returned = NeMoAutoModelForCausalLM.from_pretrained("dummy/path")
@@ -130,6 +131,7 @@ def test_from_config_uses_registry_when_available(self):
130131
# Registry provides a custom class
131132
custom_model_instance = Mock()
132133
custom_cls = Mock(return_value=custom_model_instance)
134+
custom_cls.__name__ = "MockMockMock"
133135
mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}
134136

135137
returned = NeMoAutoModelForCausalLM.from_config(cfg)
@@ -160,6 +162,7 @@ def test_from_pretrained_registry_downloads_checkpoint_files_rank0(self):
160162
# Prepare a fake custom model class and return value
161163
custom_model_instance = Mock()
162164
custom_cls = Mock(return_value=custom_model_instance)
165+
custom_cls.__name__ = "MockMockMock"
163166
mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}
164167

165168
returned = NeMoAutoModelForCausalLM.from_pretrained("dummy/repo-id")
@@ -194,6 +197,7 @@ def test_from_pretrained_registry_downloads_when_dist_uninitialized(self):
194197
# Prepare a fake custom model class and return value
195198
custom_model_instance = Mock()
196199
custom_cls = Mock(return_value=custom_model_instance)
200+
custom_cls.__name__ = "MockMockMock"
197201
mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}
198202

199203
returned = NeMoAutoModelForCausalLM.from_pretrained("dummy/repo-id")
@@ -240,7 +244,8 @@ def test_from_config_with_string_calls_autoconfig(self):
240244
# Verify AutoConfig.from_pretrained was called with the string
241245
mock_autoconfig.assert_called_once_with(
242246
"hf-internal-testing/tiny-random-gpt2",
243-
trust_remote_code=False
247+
trust_remote_code=False,
248+
attn_implementation="flash_attention_2",
244249
)
245250
# Verify the model was returned
246251
assert model is mock_model
@@ -539,6 +544,7 @@ def test_packed_sequence_and_cp_overrides_from_pretrained(
539544
else:
540545
custom_model_instance = Mock()
541546
custom_cls = Mock(return_value=custom_model_instance)
547+
custom_cls.__name__ = "MockMockMock"
542548
mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}
543549

544550
mock_hf_loader.return_value = MagicMock(config={})

0 commit comments

Comments
 (0)