@@ -1428,6 +1428,11 @@ def get_memory_usage(storage_dtype, compute_dtype):
14281428 @parameterized .expand ([None , "foo" ])
14291429 def test_works_with_automodel (self , subfolder ):
14301430 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1431+ has_generator_in_inputs = False
1432+ if "generator" in inputs_dict :
1433+ has_generator_in_inputs = True
1434+ inputs_dict ["generator" ] = torch .manual_seed (0 )
1435+
14311436 model = self .model_class (** config ).eval ()
14321437 model_cls_name = model .__class__ .__name__
14331438 model .to (torch_device )
@@ -1438,14 +1443,18 @@ def test_works_with_automodel(self, subfolder):
14381443 with tempfile .TemporaryDirectory () as tmpdir :
14391444 path = os .path .join (tmpdir , subfolder ) if subfolder else tmpdir
14401445 model .save_pretrained (path )
1441- automodel = AutoModel .from_pretrained (tmpdir , subfolder = subfolder ).to (torch_device )
1446+ automodel = AutoModel .from_pretrained (tmpdir , subfolder = subfolder ).eval ()
1447+ automodel .to (torch_device )
14421448
14431449 automodel_cls_name = automodel .__class__ .__name__
14441450 self .assertTrue (model_cls_name == automodel_cls_name )
14451451 for p1 , p2 in zip (model .parameters (), automodel .parameters ()):
1446- self .assertTrue (torch .equal (p1 , p2 ))
1452+ if not (torch .isnan (p1 ).any () and torch .isnan (p2 ).any ()):
1453+ self .assertTrue (torch .equal (p1 , p2 ))
14471454
14481455 torch .manual_seed (0 )
1456+ if has_generator_in_inputs :
1457+ inputs_dict ["generator" ] = torch .manual_seed (0 )
14491458 output_automodel = model (** inputs_dict , return_dict = False )[0 ]
14501459
14511460 self .assertTrue (torch .allclose (output [0 ], output_automodel [0 ], atol = 1e-5 ))
0 commit comments