Commit 6639f25: updates

1 parent 7c8fc64

2 files changed: +19 -17 lines

src/diffusers/hooks/group_offloading.py

Lines changed: 14 additions & 12 deletions
@@ -37,7 +37,8 @@
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-
+_GROUP_ID_LAZY_LEAF = "lazy_leafs"
+_GROUP_ID_UNMATCHED_GROUP = "top_level_unmatched_modules"
 _SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
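
The two constants give the group-id strings a single definition point; previously the same literals were retyped at each use site. A minimal sketch of the failure mode this guards against (names below are illustrative, not from the commit):

    # A retyped literal with a typo silently becomes a *different* group id,
    # so a file is written under one name and looked up under another:
    written_id = "top_level_unmatched_modules"
    queried_id = "top_level_unmatched_modues"  # typo: no error, just a "missing" file

    # A misspelled module-level constant fails loudly at use time instead:
    _GROUP_ID_UNMATCHED_GROUP = "top_level_unmatched_modules"
    # _GROUP_ID_UNMATCHED_GROUPS -> NameError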
@@ -84,7 +85,7 @@ def __init__(
 
         if self.offload_to_disk_path:
             self._group_id = _group_id
-            short_hash = self._compute_group_hash(self._group_id)
+            short_hash = _compute_group_hash(self._group_id)
             self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
 
         all_tensors = []
@@ -265,11 +266,6 @@ def offload_(self):
         for buffer in self.buffers:
             buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
-    def _compute_group_hash(self, group_id):
-        hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
-        # first 16 characters for a reasonably short but unique name
-        return hashed_id[:16]
-
 
 class GroupOffloadingHook(ModelHook):
     r"""
@@ -663,7 +659,7 @@ def _apply_group_offloading_block_level(
         stream=None,
         record_stream=False,
         onload_self=True,
-        _group_id="top_level_unmatched_modules",
+        _group_id=_GROUP_ID_UNMATCHED_GROUP,
     )
     if stream is None:
         _apply_group_offloading_hook(module, unmatched_group, None)
@@ -800,7 +796,7 @@ def _apply_group_offloading_leaf_level(
         record_stream=False,
         low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
-        name="lazy_leafs",
+        _group_id=_GROUP_ID_LAZY_LEAF,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)
 
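(This hunk appears to fix a real bug, not just naming: the old call passed name="lazy_leafs" where the keyword used for disk offloading is _group_id, so the lazy-leaf group was presumably created without the group id its safetensors filename is derived from. Using _GROUP_ID_LAZY_LEAF also keeps the id in sync with the expected-files accounting added further down.)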
@@ -910,6 +906,12 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
     raise ValueError("Group offloading is not enabled for the provided module.")
 
 
+def _compute_group_hash(group_id):
+    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+    # first 16 characters for a reasonably short but unique name
+    return hashed_id[:16]
+
+
 def _get_expected_safetensors_files(
     module: torch.nn.Module,
     offload_to_disk_path: str,
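
With the helper at module level it can be shared by the __init__ shown earlier and the filename logic below. A minimal standalone sketch of what it produces (the group id is illustrative):

    import hashlib

    def _compute_group_hash(group_id):
        hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
        # first 16 characters for a reasonably short but unique name
        return hashed_id[:16]

    # Deterministic across runs and processes, so files written to disk can be
    # matched against an independently computed expected set:
    h = _compute_group_hash("transformer_blocks.0")
    assert h == _compute_group_hash("transformer_blocks.0")
    assert len(h) == 16
    print(f"group_{h}.safetensors")  # the on-disk name this id maps to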
@@ -919,8 +921,7 @@ def _get_expected_safetensors_files(
     expected_files = set()
 
     def get_hashed_filename(group_id: str) -> str:
-        hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
-        short_hash = hashed_id[:16]
+        short_hash = _compute_group_hash(group_id)
         return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
 
     if offload_type == "block_level":
@@ -942,7 +943,7 @@ def get_hashed_filename(group_id: str) -> str:
             expected_files.add(get_hashed_filename(group_id))
 
         # Handle the group for unmatched top-level modules and parameters
-        group_id = "top_level_unmatched_modules"
+        group_id = _GROUP_ID_UNMATCHED_GROUP
         expected_files.add(get_hashed_filename(group_id))
 
     elif offload_type == "leaf_level":
@@ -972,6 +973,7 @@ def get_hashed_filename(group_id: str) -> str:
         for parent_name in parent_to_tensors:
             # A file is expected for each parent that gathers orphaned tensors
             expected_files.add(get_hashed_filename(parent_name))
+        expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
 
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")

tests/models/test_modeling_common.py

Lines changed: 5 additions & 5 deletions
@@ -1346,7 +1346,6 @@ def test_model_parallelism(self):
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
-            print(f" new_model.hf_device_map:{new_model.hf_device_map}")
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
@@ -1708,6 +1707,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
             return
 
         model.eval()
+        model.to(torch_device)
         output_without_group_offloading = model(**inputs_dict)[0]
 
         torch.manual_seed(0)
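
(The added model.to(torch_device) moves the baseline forward pass onto the accelerator before the reference output is captured; otherwise the un-offloaded output would be computed on CPU while the group-offloaded output below runs on torch_device, and the two are then compared under fairly tight tolerances.)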
@@ -1740,10 +1740,10 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
             elif missing_files:
                 raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
 
-        output_with_group_offloading = model(**inputs_dict)[0]
-        self.assertTrue(
-            torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4, rtol=1 - 4)
-        )
+            output_with_group_offloading = model(**inputs_dict)[0]
+            self.assertTrue(
+                torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4, rtol=1e-4)
+            )
 
     def test_auto_model(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
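
The substantive fix in this hunk hides in the last argument: rtol=1 - 4 is integer arithmetic, not scientific notation. A short illustration of why the old value made the assertion meaningless:

    # What the old test actually passed, versus what was intended:
    old_rtol = 1 - 4  # == -3
    new_rtol = 1e-4   # == 0.0001

    # torch.allclose checks |a - b| <= atol + rtol * |b| elementwise. With
    # rtol == -3 the right-hand side is negative for any |b| > atol / 3, so
    # the bound is unsatisfiable for all but near-zero output values.
    assert old_rtol == -3
    assert new_rtol == 0.0001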
