
Commit 9710bbc

update
1 parent ab2eff7 commit 9710bbc

3 files changed: +107 -98 lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 7 additions & 96 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import glob
 import hashlib
 import os
 from contextlib import contextmanager, nullcontext
@@ -37,8 +36,7 @@
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-_GROUP_ID_LAZY_LEAF = "lazy_leafs"
-_GROUP_ID_UNMATCHED_GROUP = "top_level_unmatched_modules"
+GROUP_ID_LAZY_LEAF = "lazy_leafs"
 _SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -609,7 +607,7 @@ def _apply_group_offloading_block_level(
 
         for i in range(0, len(submodule), num_blocks_per_group):
             current_modules = submodule[i : i + num_blocks_per_group]
-            group_id = f"{name}_{i}_{i+len(current_modules)-1}"
+            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
                 offload_device=offload_device,
@@ -622,7 +620,7 @@ def _apply_group_offloading_block_level(
                 record_stream=record_stream,
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 onload_self=True,
-                _group_id=group_id,
+                group_id=group_id,
             )
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
@@ -657,7 +655,7 @@ def _apply_group_offloading_block_level(
         stream=None,
         record_stream=False,
         onload_self=True,
-        _group_id=_GROUP_ID_UNMATCHED_GROUP,
+        group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if stream is None:
         _apply_group_offloading_hook(module, unmatched_group, None)
@@ -724,7 +722,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
-            _group_id=name,
+            group_id=name,
         )
         _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)
@@ -772,7 +770,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
-            _group_id=name,
+            group_id=name,
         )
         _apply_group_offloading_hook(parent_module, group, None)
 
@@ -794,7 +792,7 @@ def _apply_group_offloading_leaf_level(
         record_stream=False,
         low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
-        _group_id=_GROUP_ID_LAZY_LEAF,
+        group_id=GROUP_ID_LAZY_LEAF,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)
 
@@ -908,90 +906,3 @@ def _compute_group_hash(group_id):
     hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
     # first 16 characters for a reasonably short but unique name
     return hashed_id[:16]
-
-
-def _get_expected_safetensors_files(
-    module: torch.nn.Module,
-    offload_to_disk_path: str,
-    offload_type: str,
-    num_blocks_per_group: Optional[int] = None,
-) -> Set[str]:
-    expected_files = set()
-
-    def get_hashed_filename(group_id: str) -> str:
-        short_hash = _compute_group_hash(group_id)
-        return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
-
-    if offload_type == "block_level":
-        if num_blocks_per_group is None:
-            raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
-
-        # Handle groups of ModuleList and Sequential blocks
-        for name, submodule in module.named_children():
-            if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
-                continue
-
-            for i in range(0, len(submodule), num_blocks_per_group):
-                current_modules = submodule[i : i + num_blocks_per_group]
-                if not current_modules:
-                    continue
-                start_idx = i
-                end_idx = i + len(current_modules) - 1
-                group_id = f"{name}.{start_idx}_to_{end_idx}"
-                expected_files.add(get_hashed_filename(group_id))
-
-        # Handle the group for unmatched top-level modules and parameters
-        group_id = _GROUP_ID_UNMATCHED_GROUP
-        expected_files.add(get_hashed_filename(group_id))
-
-    elif offload_type == "leaf_level":
-        # Handle leaf-level module groups
-        for name, submodule in module.named_modules():
-            if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
-                # These groups will always have parameters, so a file is expected
-                expected_files.add(get_hashed_filename(name))
-
-        # Handle groups for non-leaf parameters/buffers
-        modules_with_group_offloading = {
-            name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
-        }
-        parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
-        buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
-
-        all_orphans = parameters + buffers
-        if all_orphans:
-            parent_to_tensors = {}
-            module_dict = dict(module.named_modules())
-            for tensor_name, _ in all_orphans:
-                parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
-                if parent_name not in parent_to_tensors:
-                    parent_to_tensors[parent_name] = []
-                parent_to_tensors[parent_name].append(tensor_name)
-
-            for parent_name in parent_to_tensors:
-                # A file is expected for each parent that gathers orphaned tensors
-                expected_files.add(get_hashed_filename(parent_name))
-        expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
-
-    else:
-        raise ValueError(f"Unsupported offload_type: {offload_type}")
-
-    return expected_files
-
-
-def _check_safetensors_serialization(
-    module: torch.nn.Module,
-    offload_to_disk_path: str,
-    offload_type: str,
-    num_blocks_per_group: Optional[int] = None,
-) -> bool:
-    if not os.path.isdir(offload_to_disk_path):
-        return False, None, None
-
-    expected_files = _get_expected_safetensors_files(module, offload_to_disk_path, offload_type, num_blocks_per_group)
-    actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
-    missing_files = expected_files - actual_files
-    extra_files = actual_files - expected_files
-
-    is_correct = not missing_files and not extra_files
-    return is_correct, extra_files, missing_files
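
Side note: this commit renames block-level group IDs to "{name}_{start}_{end}" and keeps hashing them into short safetensors filenames. Below is a minimal sketch of that mapping, mirroring _compute_group_hash and the get_hashed_filename helper from the diff; the offload directory and group name are illustrative, not from the commit:

import hashlib
import os

def compute_group_hash(group_id: str) -> str:
    # SHA-256 of the group id; the first 16 hex characters keep the
    # filename short while staying effectively unique (as in the diff)
    return hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]

# e.g. the first four blocks of a ModuleList attribute named "blocks",
# grouped with num_blocks_per_group=4, get group_id "blocks_0_3"
offload_dir = "/tmp/offload"  # hypothetical path
print(os.path.join(offload_dir, f"group_{compute_group_hash('blocks_0_3')}.safetensors"))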

src/diffusers/utils/testing_utils.py

Lines changed: 99 additions & 1 deletion
@@ -1,4 +1,5 @@
 import functools
+import glob
 import importlib
 import importlib.metadata
 import inspect
@@ -18,7 +19,7 @@
 from contextlib import contextmanager
 from io import BytesIO, StringIO
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import PIL.Image
@@ -1377,6 +1378,103 @@ def get_device_properties() -> DeviceProperties:
 else:
     DevicePropertiesUserDict = UserDict
 
+if is_torch_available():
+    from diffusers.hooks.group_offloading import (
+        GROUP_ID_LAZY_LEAF,
+        _SUPPORTED_PYTORCH_LAYERS,
+        _compute_group_hash,
+        _find_parent_module_in_module_dict,
+        _gather_buffers_with_no_group_offloading_parent,
+        _gather_parameters_with_no_group_offloading_parent,
+    )
+
+    def _get_expected_safetensors_files(
+        module: torch.nn.Module,
+        offload_to_disk_path: str,
+        offload_type: str,
+        num_blocks_per_group: Optional[int] = None,
+    ) -> Set[str]:
+        expected_files = set()
+
+        def get_hashed_filename(group_id: str) -> str:
+            short_hash = _compute_group_hash(group_id)
+            return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+        if offload_type == "block_level":
+            if num_blocks_per_group is None:
+                raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
+
+            # Handle groups of ModuleList and Sequential blocks
+            unmatched_modules = []
+            for name, submodule in module.named_children():
+                if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+                    unmatched_modules.append(module)
+                    continue
+
+                for i in range(0, len(submodule), num_blocks_per_group):
+                    current_modules = submodule[i : i + num_blocks_per_group]
+                    if not current_modules:
+                        continue
+                    group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+                    expected_files.add(get_hashed_filename(group_id))
+
+            # Handle the group for unmatched top-level modules and parameters
+            for module in unmatched_modules:
+                expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
+
+        elif offload_type == "leaf_level":
+            # Handle leaf-level module groups
+            for name, submodule in module.named_modules():
+                if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+                    # These groups will always have parameters, so a file is expected
+                    expected_files.add(get_hashed_filename(name))
+
+            # Handle groups for non-leaf parameters/buffers
+            modules_with_group_offloading = {
+                name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
+            }
+            parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+            buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+            all_orphans = parameters + buffers
+            if all_orphans:
+                parent_to_tensors = {}
+                module_dict = dict(module.named_modules())
+                for tensor_name, _ in all_orphans:
+                    parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
+                    if parent_name not in parent_to_tensors:
+                        parent_to_tensors[parent_name] = []
+                    parent_to_tensors[parent_name].append(tensor_name)
+
+                for parent_name in parent_to_tensors:
+                    # A file is expected for each parent that gathers orphaned tensors
+                    expected_files.add(get_hashed_filename(parent_name))
+            expected_files.add(get_hashed_filename(GROUP_ID_LAZY_LEAF))
+
+        else:
+            raise ValueError(f"Unsupported offload_type: {offload_type}")
+
+        return expected_files
+
+    def _check_safetensors_serialization(
+        module: torch.nn.Module,
+        offload_to_disk_path: str,
+        offload_type: str,
+        num_blocks_per_group: Optional[int] = None,
+    ) -> Tuple[bool, Optional[Set[str]], Optional[Set[str]]]:
+        if not os.path.isdir(offload_to_disk_path):
+            return False, None, None
+
+        expected_files = _get_expected_safetensors_files(
+            module, offload_to_disk_path, offload_type, num_blocks_per_group
+        )
+        actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
+        missing_files = expected_files - actual_files
+        extra_files = actual_files - expected_files
+
+        is_correct = not missing_files and not extra_files
+        return is_correct, extra_files, missing_files
+
 
 class Expectations(DevicePropertiesUserDict):
     def get_expectation(self) -> Any:
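
To show how the relocated helper is meant to be driven from a test, here is a hedged sketch. TinyModel, the temporary directory, and the offload_to_disk_path argument to apply_group_offloading are illustrative assumptions about the disk-offload feature under test; only the _check_safetensors_serialization call signature comes from this diff:

import tempfile

import torch

from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils.testing_utils import _check_safetensors_serialization

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # a ModuleList gives block_level offloading something to group
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with tempfile.TemporaryDirectory() as offload_dir:
    model = TinyModel()
    apply_group_offloading(
        model,
        onload_device=device,
        offload_type="block_level",
        num_blocks_per_group=2,
        offload_to_disk_path=offload_dir,  # assumed keyword for disk offload
    )
    model(torch.randn(1, 4, device=device))  # run once so every group offloads to disk
    is_correct, extra_files, missing_files = _check_safetensors_serialization(
        model, offload_dir, "block_level", num_blocks_per_group=2
    )
    assert is_correct, f"extra={extra_files}, missing={missing_files}"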

tests/models/test_modeling_common.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,6 @@
 from pytest import skip
 from requests.exceptions import HTTPError
 
-from diffusers.hooks.group_offloading import _check_safetensors_serialization
 from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor,
@@ -62,6 +61,7 @@
 from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    _check_safetensors_serialization,
     backend_empty_cache,
     backend_max_memory_allocated,
     backend_reset_peak_memory_stats,
