Commit 207fb07

update
1 parent b75b204 commit 207fb07

File tree

1 file changed (+43 −0)

tests/models/test_modeling_common.py

Lines changed: 43 additions & 0 deletions
@@ -1507,6 +1507,49 @@ def run_forward(model):
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
 
+    def test_error_when_disk_offload_run_together_with_group_offloading(self):
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model1 = self.model_class(**config).eval()
+        model1 = model1.to(torch_device)
+
+        def has_accelerate_hooks(module):
+            from accelerate.hooks import AlignDevicesHook, CpuOffload
+
+            count = 0
+            for name, submodule in module.named_modules():
+                if not hasattr(submodule, "_hf_hook"):
+                    continue
+                if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
+                    print(f"Found {name} with hook {submodule._hf_hook}")
+                    count += 1
+            return count > 0
+
+        model_size = compute_module_sizes(model1)[""]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model1.cpu().save_pretrained(tmp_dir)
+            max_size = int(self.model_split_percents[0] * model_size)
+            max_memory = {0: max_size, "cpu": max_size}
+            new_model = self.model_class.from_pretrained(
+                tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
+            )
+            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+            assert has_accelerate_hooks(new_model)
+
+            del model1
+            torch.cuda.synchronize()
+            torch.cpu.synchronize()
+            torch.cuda.empty_cache()
+            gc.collect()
+
+            model2 = self.model_class(**config)
+
+            # ===============================
+            # We still have accelerate hooks on model2 in some cases?????
+            assert has_accelerate_hooks(model2)
+            # ===============================
+
+            model2.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
+
 
 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
