diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index f3968e853476..c1358ac201cf 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -32,7 +32,7 @@ class ModelHook: _is_stateful = False def __init__(self): - self.fn_ref: "FunctionReference" = None + self.fn_ref: "HookFunctionReference" = None def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" @@ -101,12 +101,27 @@ def reset_state(self, module: torch.nn.Module): return module -class FunctionReference: +class HookFunctionReference: def __init__(self) -> None: + """A container class that maintains mutable references to forward pass functions in a hook chain. + + Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the + entire forward pass structure. + + Attributes: + pre_forward: A callable that processes inputs before the main forward pass. + post_forward: A callable that processes outputs after the main forward pass. + forward: The current forward function in the hook chain. + original_forward: The original forward function, stored when a hook provides a custom new_forward. + + The class enables hook removal by allowing updates to the forward chain through reference modification rather + than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to + be updated, preserving the execution order of the remaining hooks. + """ self.pre_forward = None self.post_forward = None - self.old_forward = None - self.overwritten_forward = None + self.forward = None + self.original_forward = None class HookRegistry: @@ -125,24 +140,24 @@ def register_hook(self, hook: ModelHook, name: str) -> None: self._module_ref = hook.initialize_hook(self._module_ref) - def create_new_forward(function_reference: FunctionReference): + def create_new_forward(function_reference: HookFunctionReference): def new_forward(module, *args, **kwargs): args, kwargs = function_reference.pre_forward(module, *args, **kwargs) - output = function_reference.old_forward(*args, **kwargs) + output = function_reference.forward(*args, **kwargs) return function_reference.post_forward(module, output) return new_forward forward = self._module_ref.forward - fn_ref = FunctionReference() + fn_ref = HookFunctionReference() fn_ref.pre_forward = hook.pre_forward fn_ref.post_forward = hook.post_forward - fn_ref.old_forward = forward + fn_ref.forward = forward if hasattr(hook, "new_forward"): - fn_ref.overwritten_forward = forward - fn_ref.old_forward = functools.update_wrapper( + fn_ref.original_forward = forward + fn_ref.forward = functools.update_wrapper( functools.partial(hook.new_forward, self._module_ref), hook.new_forward ) @@ -160,25 +175,28 @@ def get_hook(self, name: str) -> Optional[ModelHook]: return self.hooks.get(name, None) def remove_hook(self, name: str, recurse: bool = True) -> None: - num_hooks = len(self._hook_order) - if name in self.hooks.keys(): - hook = self.hooks[name] - index = self._hook_order.index(name) - fn_ref = self._fn_refs[index] - - old_forward = fn_ref.old_forward - if fn_ref.overwritten_forward is not None: - old_forward = fn_ref.overwritten_forward + if name not in self.hooks.keys(): + logger.warning(f"hook: {name} was not found in HookRegistry") + return - if index == num_hooks - 1: - self._module_ref.forward = old_forward - else: - self._fn_refs[index + 1].old_forward = old_forward - - self._module_ref = hook.deinitalize_hook(self._module_ref) - del self.hooks[name] - self._hook_order.pop(index) - self._fn_refs.pop(index) + num_hooks = len(self._hook_order) + hook = self.hooks[name] + index = self._hook_order.index(name) + fn_ref = self._fn_refs[index] + + old_forward = fn_ref.forward + if fn_ref.original_forward is not None: + old_forward = fn_ref.original_forward + + if index == num_hooks - 1: + self._module_ref.forward = old_forward + else: + self._fn_refs[index + 1].forward = old_forward + + self._module_ref = hook.deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.pop(index) + self._fn_refs.pop(index) if recurse: for module_name, module in self._module_ref.named_modules(): diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 49a75cfdc2e8..9f8597d52f8c 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -162,7 +162,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: ) if should_compute_attention: - output = self.fn_ref.overwritten_forward(*args, **kwargs) + output = self.fn_ref.original_forward(*args, **kwargs) else: output = self.state.cache diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 65fea530d1ef..74bd43c52315 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -126,7 +126,7 @@ def new_forward(self, module, *args, **kwargs): logger.debug("SkipLayerHook new_forward") if self.skip_layer: return args[0] - return self.fn_ref.overwritten_forward(*args, **kwargs) + return self.fn_ref.original_forward(*args, **kwargs) def post_forward(self, module, output): logger.debug("SkipLayerHook post_forward") @@ -174,14 +174,12 @@ def test_hook_registry(self): self.assertEqual(len(registry.hooks), 2) self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"]) - self.assertEqual(len(registry._fn_refs), 2) self.assertEqual(registry_repr, expected_repr) registry.remove_hook("add_hook") self.assertEqual(len(registry.hooks), 1) self.assertEqual(registry._hook_order, ["multiply_hook"]) - self.assertEqual(len(registry._fn_refs), 1) def test_stateful_hook(self): registry = HookRegistry.check_if_exists_or_initialize(self.model)