@@ -1507,6 +1507,49 @@ def run_forward(model):
15071507 self .assertTrue (torch .allclose (output_without_group_offloading , output_with_group_offloading3 , atol = 1e-5 ))
15081508 self .assertTrue (torch .allclose (output_without_group_offloading , output_with_group_offloading4 , atol = 1e-5 ))
15091509
def test_error_when_disk_offload_run_together_with_group_offloading(self):
    """Load a model with accelerate's device-map offloading, then enable group offloading.

    Steps performed:
      1. Build a model from the common config, save it to a temporary directory and
         reload it with ``device_map="auto"`` under a tight ``max_memory`` budget.
      2. Check that the reloaded model respects its device map and carries
         accelerate offloading hooks.
      3. Free the first model, build a fresh model from the same config and call
         ``enable_group_offload`` on it.

    NOTE(review): the test name promises an error, but nothing here asserts that one
    is raised, and group offloading is enabled on the freshly built ``model2`` rather
    than on the disk-offloaded ``new_model``. This looks like work in progress;
    confirm the intended expectation before relying on this test.
    """
    # Only the config is needed; the sample inputs are never fed through a model here.
    config, _ = self.prepare_init_args_and_inputs_for_common()
    model1 = self.model_class(**config).eval()
    model1 = model1.to(torch_device)

    def has_accelerate_hooks(module):
        """Return True if any submodule's ``_hf_hook`` is an AlignDevicesHook or CpuOffload."""
        from accelerate.hooks import AlignDevicesHook, CpuOffload

        return any(
            isinstance(getattr(submodule, "_hf_hook", None), (AlignDevicesHook, CpuOffload))
            for _, submodule in module.named_modules()
        )

    # The "" entry is presumably the total size of the whole model -- TODO confirm
    # against `compute_module_sizes`.
    model_size = compute_module_sizes(model1)[""]
    with tempfile.TemporaryDirectory() as tmp_dir:
        model1.cpu().save_pretrained(tmp_dir)
        # Cap device 0 and the CPU to the same fraction of the model size.
        max_size = int(self.model_split_percents[0] * model_size)
        max_memory = {0: max_size, "cpu": max_size}
        new_model = self.model_class.from_pretrained(
            tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
        )
        self.check_device_map_is_respected(new_model, new_model.hf_device_map)
        self.assertTrue(
            has_accelerate_hooks(new_model),
            "Expected accelerate hooks on the model loaded with device_map='auto'.",
        )

    del model1
    # `torch.cuda.synchronize()` raises when CUDA is unavailable, so only call it
    # when the CUDA backend is actually present.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    torch.cpu.synchronize()
    torch.cuda.empty_cache()
    gc.collect()

    model2 = self.model_class(**config)

    # NOTE(review): a freshly constructed model is asserted to carry accelerate
    # hooks. The original author observed this "in some cases" and flagged it as
    # unexplained; this is a debugging probe, not a verified invariant.
    self.assertTrue(
        has_accelerate_hooks(model2),
        "Expected accelerate hooks on the freshly constructed model.",
    )

    model2.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
15101553
15111554@is_staging_test
15121555class ModelPushToHubTester (unittest .TestCase ):
0 commit comments