@@ -1428,6 +1428,39 @@ def test_sharded_checkpoints_with_variant(self):
14281428
14291429            self .assertTrue (torch .allclose (base_output [0 ], new_output [0 ], atol = 1e-5 ))
14301430
1431+     @require_torch_accelerator  
1432+     def  test_sharded_checkpoints_with_parallel_loading (self ):
1433+         torch .manual_seed (0 )
1434+         config , inputs_dict  =  self .prepare_init_args_and_inputs_for_common ()
1435+         model  =  self .model_class (** config ).eval ()
1436+         model  =  model .to (torch_device )
1437+ 
1438+         base_output  =  model (** inputs_dict )
1439+ 
1440+         model_size  =  compute_module_persistent_sizes (model )["" ]
1441+         max_shard_size  =  int ((model_size  *  0.75 ) /  (2 ** 10 ))  # Convert to KB as these test models are small. 
1442+         with  tempfile .TemporaryDirectory () as  tmp_dir :
1443+             model .cpu ().save_pretrained (tmp_dir , max_shard_size = f"{ max_shard_size }  )
1444+             self .assertTrue (os .path .exists (os .path .join (tmp_dir , SAFE_WEIGHTS_INDEX_NAME )))
1445+ 
1446+             # Now check if the right number of shards exists. First, let's get the number of shards. 
1447+             # Since this number can be dependent on the model being tested, it's important that we calculate it 
1448+             # instead of hardcoding it. 
1449+             expected_num_shards  =  caculate_expected_num_shards (os .path .join (tmp_dir , SAFE_WEIGHTS_INDEX_NAME ))
1450+             actual_num_shards  =  len ([file  for  file  in  os .listdir (tmp_dir ) if  file .endswith (".safetensors" )])
1451+             self .assertTrue (actual_num_shards  ==  expected_num_shards )
1452+ 
1453+             # Load with parallel loading 
1454+             os .environ ["HF_ENABLE_PARALLEL_LOADING" ] =  "yes" 
1455+             new_model  =  self .model_class .from_pretrained (tmp_dir ).eval ()
1456+             new_model  =  new_model .to (torch_device )
1457+ 
1458+             torch .manual_seed (0 )
1459+             if  "generator"  in  inputs_dict :
1460+                 _ , inputs_dict  =  self .prepare_init_args_and_inputs_for_common ()
1461+             new_output  =  new_model (** inputs_dict )
1462+             self .assertTrue (torch .allclose (base_output [0 ], new_output [0 ], atol = 1e-5 ))
1463+ 
14311464    @require_torch_accelerator  
14321465    def  test_sharded_checkpoints_device_map (self ):
14331466        if  self .model_class ._no_split_modules  is  None :
0 commit comments