Skip to content

Commit b950c74

Browse files
committed
Fix leaf-level group offload root hook
1 parent 8d059e6 commit b950c74

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,28 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
982982
)
983983
_apply_group_offloading_hook(parent_module, group, config=config)
984984

985+
# Ensure the top-level module also has a group_offloading hook so hook presence checks pass,
986+
# even when it holds no parameters/buffers itself.
987+
if config.stream is None:
988+
root_registry = HookRegistry.check_if_exists_or_initialize(module)
989+
if root_registry.get_hook(_GROUP_OFFLOADING) is None:
990+
empty_group = ModuleGroup(
991+
modules=[],
992+
offload_device=config.offload_device,
993+
onload_device=config.onload_device,
994+
offload_to_disk_path=None,
995+
offload_leader=module,
996+
onload_leader=module,
997+
parameters=[],
998+
buffers=[],
999+
non_blocking=False,
1000+
stream=None,
1001+
record_stream=False,
1002+
onload_self=True,
1003+
group_id=f"{config.module_prefix}{module.__class__.__name__}_empty_group",
1004+
)
1005+
root_registry.register_hook(GroupOffloadingHook(empty_group, config=config), _GROUP_OFFLOADING)
1006+
9851007
if config.stream is not None:
9861008
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
9871009
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the

0 commit comments

Comments
 (0)