
Commit 7a66cda

removed hook declaration in leaf_level offloading
1 parent 372c8ab commit 7a66cda
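
In effect (reading the diff): pin_groups is no longer stashed on the HookRegistry as a private _group_offload_pin_groups attribute, written in both apply_group_offloading and _apply_group_offloading. It now travels inside GroupOffloadingConfig and is handed to the LazyPrefetchGroupOffloadingHook constructor, whose post_forward reads self.pin_groups directly.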

2 files changed (+14, -19 lines)

src/diffusers/hooks/group_offloading.py

Lines changed: 12 additions & 17 deletions
@@ -376,9 +376,10 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(self):
+    def __init__(self, pin_groups: Optional[Union[str, Callable]] = None):
         self.execution_order: List[Tuple[str, torch.nn.Module]] = []
         self._layer_execution_tracker_module_names = set()
+        self.pin_groups = pin_groups
 
     def initialize_hook(self, module):
         def make_execution_order_update_callback(current_name, current_submodule):
@@ -460,8 +461,7 @@ def post_forward(self, module, output):
             group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
             group_offloading_hooks[i].next_group.onload_self = False
 
-        pin_groups = getattr(base_module_registry, "_group_offload_pin_groups", None)
-        if pin_groups is not None and num_executed > 0:
+        if self.pin_groups is not None and num_executed > 0:
             param_exec_info = []
             for idx, ((name, submodule), hook) in enumerate(zip(self.execution_order, group_offloading_hooks)):
                 if hook is None:
@@ -473,22 +473,22 @@ def post_forward(self, module, output):
             num_param_modules = len(param_exec_info)
             if num_param_modules > 0:
                 pinned_indices = set()
-                if isinstance(pin_groups, str):
-                    if pin_groups == "all":
+                if isinstance(self.pin_groups, str):
+                    if self.pin_groups == "all":
                         pinned_indices = set(range(num_param_modules))
-                    elif pin_groups == "first_last":
+                    elif self.pin_groups == "first_last":
                         pinned_indices.add(0)
                         pinned_indices.add(num_param_modules - 1)
-                elif callable(pin_groups):
+                elif callable(self.pin_groups):
                     for idx, (name, submodule, _) in enumerate(param_exec_info):
                         should_pin = False
                         try:
-                            should_pin = bool(pin_groups(submodule))
+                            should_pin = bool(self.pin_groups(submodule))
                         except TypeError:
                             try:
-                                should_pin = bool(pin_groups(name, submodule))
+                                should_pin = bool(self.pin_groups(name, submodule))
                             except TypeError:
-                                should_pin = bool(pin_groups(name, submodule, idx))
+                                should_pin = bool(self.pin_groups(name, submodule, idx))
                         if should_pin:
                             pinned_indices.add(idx)
 
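
The nested try/except TypeError dispatch above probes a user-supplied pin_groups callable with one, two, then three arguments. A minimal sketch (these helpers are illustrative, not part of the commit) of callables each arity would accept:

import torch

# Arity 1 -- tried first: decide from the submodule alone.
def pin_norm_layers(submodule):
    return isinstance(submodule, torch.nn.LayerNorm)

# Arity 2 -- used when the one-argument call raises TypeError:
# decide from the qualified module name as well.
def pin_attention_blocks(name, submodule):
    return "attn" in name

# Arity 3 -- the final fallback: decide from the execution-order index.
def pin_first_two_groups(name, submodule, idx):
    return idx < 2

One caveat of probing by TypeError: a one-argument callable whose own body raises TypeError will be silently retried with two and then three arguments.
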
@@ -651,8 +651,6 @@ def apply_group_offloading(
     pin_groups = normalized_pin_groups
 
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
-    registry = HookRegistry.check_if_exists_or_initialize(module)
-    registry._group_offload_pin_groups = pin_groups
 
     config = GroupOffloadingConfig(
         onload_device=onload_device,
@@ -671,9 +669,6 @@
 
 
 def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
-    registry = HookRegistry.check_if_exists_or_initialize(module)
-    registry._group_offload_pin_groups = config.pin_groups
-
     if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
         _apply_group_offloading_block_level(module, config)
     elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
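
With both registry writes gone, GroupOffloadingConfig is the sole carrier of pin_groups from apply_group_offloading down to the hook construction below.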
@@ -966,8 +961,8 @@ def _apply_lazy_group_offloading_hook(
     if registry.get_hook(_GROUP_OFFLOADING) is None:
         hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
-
-    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
+
+    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook(pin_groups=config.pin_groups)
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
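For context, a minimal usage sketch of the new wiring. The pin_groups keyword on apply_group_offloading is an assumption inferred from the normalized_pin_groups handling above, and use_stream=True is assumed because lazy prefetching is the stream-based path; none of these lines appear in the commit:

import torch
from diffusers.hooks.group_offloading import apply_group_offloading

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 16),
)

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",  # the path named in the commit message
    use_stream=True,            # assumed: the lazy prefetch hook is registered on the stream path
    pin_groups="first_last",    # assumed keyword; or "all", or a callable as sketched earlier
)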

tests/hooks/test_group_offloading.py

Lines changed: 2 additions & 2 deletions
@@ -362,8 +362,8 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
         self.assertLess(
             cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
         )
-
-    def test_block_level_pin_first_last_groups_stay_on_device(self):
+
+    def test_block_level_pin_groups_stay_on_device(self):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
 