@@ -103,6 +103,7 @@ def test_from_pretrained_uses_registry_when_available(self):
103103 # Prepare a fake custom model class and return value
104104 custom_model_instance = Mock ()
105105 custom_cls = Mock (return_value = custom_model_instance )
106+ custom_cls .__name__ = "MockMockMock"
106107 mock_registry .model_arch_name_to_cls = {"CustomArch" : custom_cls }
107108
108109 returned = NeMoAutoModelForCausalLM .from_pretrained ("dummy/path" )
@@ -130,6 +131,7 @@ def test_from_config_uses_registry_when_available(self):
130131 # Registry provides a custom class
131132 custom_model_instance = Mock ()
132133 custom_cls = Mock (return_value = custom_model_instance )
134+ custom_cls .__name__ = "MockMockMock"
133135 mock_registry .model_arch_name_to_cls = {"CustomArch" : custom_cls }
134136
135137 returned = NeMoAutoModelForCausalLM .from_config (cfg )
@@ -160,6 +162,7 @@ def test_from_pretrained_registry_downloads_checkpoint_files_rank0(self):
160162 # Prepare a fake custom model class and return value
161163 custom_model_instance = Mock ()
162164 custom_cls = Mock (return_value = custom_model_instance )
165+ custom_cls .__name__ = "MockMockMock"
163166 mock_registry .model_arch_name_to_cls = {"CustomArch" : custom_cls }
164167
165168 returned = NeMoAutoModelForCausalLM .from_pretrained ("dummy/repo-id" )
@@ -194,6 +197,7 @@ def test_from_pretrained_registry_downloads_when_dist_uninitialized(self):
194197 # Prepare a fake custom model class and return value
195198 custom_model_instance = Mock ()
196199 custom_cls = Mock (return_value = custom_model_instance )
200+ custom_cls .__name__ = "MockMockMock"
197201 mock_registry .model_arch_name_to_cls = {"CustomArch" : custom_cls }
198202
199203 returned = NeMoAutoModelForCausalLM .from_pretrained ("dummy/repo-id" )
@@ -240,7 +244,8 @@ def test_from_config_with_string_calls_autoconfig(self):
240244 # Verify AutoConfig.from_pretrained was called with the string
241245 mock_autoconfig .assert_called_once_with (
242246 "hf-internal-testing/tiny-random-gpt2" ,
243- trust_remote_code = False
247+ trust_remote_code = False ,
248+ attn_implementation = "flash_attention_2" ,
244249 )
245250 # Verify the model was returned
246251 assert model is mock_model
@@ -539,6 +544,7 @@ def test_packed_sequence_and_cp_overrides_from_pretrained(
539544 else :
540545 custom_model_instance = Mock ()
541546 custom_cls = Mock (return_value = custom_model_instance )
547+ custom_cls .__name__ = "MockMockMock"
542548 mock_registry .model_arch_name_to_cls = {"CustomArch" : custom_cls }
543549
544550 mock_hf_loader .return_value = MagicMock (config = {})
0 commit comments