Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 46 additions & 28 deletions src/diffusers/hooks/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class ModelHook:
_is_stateful = False

def __init__(self):
self.fn_ref: "FunctionReference" = None
self.fn_ref: "HookFunctionReference" = None

def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Expand Down Expand Up @@ -101,12 +101,27 @@ def reset_state(self, module: torch.nn.Module):
return module


class FunctionReference:
class HookFunctionReference:
def __init__(self) -> None:
"""A container class that maintains mutable references to forward pass functions in a hook chain.

Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
entire forward pass structure.

Attributes:
pre_forward: A callable that processes inputs before the main forward pass.
post_forward: A callable that processes outputs after the main forward pass.
forward: The current forward function in the hook chain.
original_forward: The original forward function, stored when a hook provides a custom new_forward.

The class enables hook removal by allowing updates to the forward chain through reference modification rather
than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
be updated, preserving the execution order of the remaining hooks.
"""
self.pre_forward = None
self.post_forward = None
self.old_forward = None
self.overwritten_forward = None
self.forward = None
self.original_forward = None


class HookRegistry:
Expand All @@ -125,24 +140,24 @@ def register_hook(self, hook: ModelHook, name: str) -> None:

self._module_ref = hook.initialize_hook(self._module_ref)

def create_new_forward(function_reference: FunctionReference):
def create_new_forward(function_reference: HookFunctionReference):
def new_forward(module, *args, **kwargs):
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
output = function_reference.old_forward(*args, **kwargs)
output = function_reference.forward(*args, **kwargs)
return function_reference.post_forward(module, output)

return new_forward

forward = self._module_ref.forward

fn_ref = FunctionReference()
fn_ref = HookFunctionReference()
fn_ref.pre_forward = hook.pre_forward
fn_ref.post_forward = hook.post_forward
fn_ref.old_forward = forward
fn_ref.forward = forward

if hasattr(hook, "new_forward"):
fn_ref.overwritten_forward = forward
fn_ref.old_forward = functools.update_wrapper(
fn_ref.original_forward = forward
fn_ref.forward = functools.update_wrapper(
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
)

Expand All @@ -160,25 +175,28 @@ def get_hook(self, name: str) -> Optional[ModelHook]:
return self.hooks.get(name, None)

def remove_hook(self, name: str, recurse: bool = True) -> None:
num_hooks = len(self._hook_order)
if name in self.hooks.keys():
hook = self.hooks[name]
index = self._hook_order.index(name)
fn_ref = self._fn_refs[index]

old_forward = fn_ref.old_forward
if fn_ref.overwritten_forward is not None:
old_forward = fn_ref.overwritten_forward
if name not in self.hooks.keys():
logger.warning(f"hook: {name} was not found in HookRegistry")
return

if index == num_hooks - 1:
self._module_ref.forward = old_forward
else:
self._fn_refs[index + 1].old_forward = old_forward

self._module_ref = hook.deinitalize_hook(self._module_ref)
del self.hooks[name]
self._hook_order.pop(index)
self._fn_refs.pop(index)
num_hooks = len(self._hook_order)
hook = self.hooks[name]
index = self._hook_order.index(name)
fn_ref = self._fn_refs[index]

old_forward = fn_ref.forward
if fn_ref.original_forward is not None:
old_forward = fn_ref.original_forward

if index == num_hooks - 1:
self._module_ref.forward = old_forward
else:
self._fn_refs[index + 1].forward = old_forward

self._module_ref = hook.deinitalize_hook(self._module_ref)
del self.hooks[name]
self._hook_order.pop(index)
self._fn_refs.pop(index)

if recurse:
for module_name, module in self._module_ref.named_modules():
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/hooks/pyramid_attention_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
)

if should_compute_attention:
output = self.fn_ref.overwritten_forward(*args, **kwargs)
output = self.fn_ref.original_forward(*args, **kwargs)
else:
output = self.state.cache

Expand Down
4 changes: 1 addition & 3 deletions tests/hooks/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def new_forward(self, module, *args, **kwargs):
logger.debug("SkipLayerHook new_forward")
if self.skip_layer:
return args[0]
return self.fn_ref.overwritten_forward(*args, **kwargs)
return self.fn_ref.original_forward(*args, **kwargs)

def post_forward(self, module, output):
logger.debug("SkipLayerHook post_forward")
Expand Down Expand Up @@ -174,14 +174,12 @@ def test_hook_registry(self):

self.assertEqual(len(registry.hooks), 2)
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
self.assertEqual(len(registry._fn_refs), 2)
self.assertEqual(registry_repr, expected_repr)

registry.remove_hook("add_hook")

self.assertEqual(len(registry.hooks), 1)
self.assertEqual(registry._hook_order, ["multiply_hook"])
self.assertEqual(len(registry._fn_refs), 1)

def test_stateful_hook(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
Expand Down