Skip to content

Commit 447e881

Browse files
committed
refactor
1 parent 357668e commit 447e881

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,9 @@ class GroupOffloadingHook(ModelHook):
281281

282282
_is_stateful = False
283283

284-
def __init__(
285-
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
286-
) -> None:
284+
def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
287285
self.group = group
288-
self.next_group = next_group
286+
self.next_group: Optional[ModuleGroup] = None
289287
self.config = config
290288

291289
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -609,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
609607
# Apply group offloading hooks to the module groups
610608
for i, group in enumerate(matched_module_groups):
611609
for group_module in group.modules:
612-
_apply_group_offloading_hook(group_module, group, None, config=config)
610+
_apply_group_offloading_hook(group_module, group, config=config)
613611

614612
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
615613
# when the forward pass of this module is called. This is because the top-level module is not
@@ -638,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
638636
group_id=f"{module.__class__.__name__}_unmatched_group",
639637
)
640638
if config.stream is None:
641-
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
639+
_apply_group_offloading_hook(module, unmatched_group, config=config)
642640
else:
643-
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
641+
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
644642

645643

646644
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -669,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
669667
onload_self=True,
670668
group_id=name,
671669
)
672-
_apply_group_offloading_hook(submodule, group, None, config=config)
670+
_apply_group_offloading_hook(submodule, group, config=config)
673671
modules_with_group_offloading.add(name)
674672

675673
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -716,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
716714
onload_self=True,
717715
group_id=name,
718716
)
719-
_apply_group_offloading_hook(parent_module, group, None, config=config)
717+
_apply_group_offloading_hook(parent_module, group, config=config)
720718

721719
if config.stream is not None:
722720
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -738,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
738736
onload_self=True,
739737
group_id=_GROUP_ID_LAZY_LEAF,
740738
)
741-
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
739+
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
742740

743741

744742
def _apply_group_offloading_hook(
745743
module: torch.nn.Module,
746744
group: ModuleGroup,
747-
next_group: Optional[ModuleGroup] = None,
748745
*,
749746
config: GroupOffloadingConfig,
750747
) -> None:
@@ -753,14 +750,13 @@ def _apply_group_offloading_hook(
753750
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
754751
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
755752
if registry.get_hook(_GROUP_OFFLOADING) is None:
756-
hook = GroupOffloadingHook(group, next_group, config=config)
753+
hook = GroupOffloadingHook(group, config=config)
757754
registry.register_hook(hook, _GROUP_OFFLOADING)
758755

759756

760757
def _apply_lazy_group_offloading_hook(
761758
module: torch.nn.Module,
762759
group: ModuleGroup,
763-
next_group: Optional[ModuleGroup] = None,
764760
*,
765761
config: GroupOffloadingConfig,
766762
) -> None:
@@ -769,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
769765
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
770766
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
771767
if registry.get_hook(_GROUP_OFFLOADING) is None:
772-
hook = GroupOffloadingHook(group, next_group, config=config)
768+
hook = GroupOffloadingHook(group, config=config)
773769
registry.register_hook(hook, _GROUP_OFFLOADING)
774770

775771
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()

0 commit comments

Comments
 (0)