
Commit c53a4ab

update
1 parent 3f3e26a commit c53a4ab

File tree: 2 files changed, +37 -29 lines


src/diffusers/hooks/hooks.py

Lines changed: 36 additions & 28 deletions
@@ -32,7 +32,7 @@ class ModelHook:
     _is_stateful = False

     def __init__(self):
-        self.fn_ref: "FunctionReference" = None
+        self.fn_ref: "HookFunctionReference" = None

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""

@@ -101,12 +101,17 @@ def reset_state(self, module: torch.nn.Module):
         return module


-class FunctionReference:
+class HookFunctionReference:
     def __init__(self) -> None:
+        """
+        Holding class for forward function references used in Diffusers hooks. This struct allows you to easily swap
+        out the forward function when a hook is removed from a module's hook registry.
+
+        """
         self.pre_forward = None
         self.post_forward = None
-        self.old_forward = None
-        self.overwritten_forward = None
+        self.forward = None
+        self.original_forward = None


 class HookRegistry:

@@ -125,24 +130,24 @@ def register_hook(self, hook: ModelHook, name: str) -> None:

         self._module_ref = hook.initialize_hook(self._module_ref)

-        def create_new_forward(function_reference: FunctionReference):
+        def create_new_forward(function_reference: HookFunctionReference):
             def new_forward(module, *args, **kwargs):
                 args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
-                output = function_reference.old_forward(*args, **kwargs)
+                output = function_reference.forward(*args, **kwargs)
                 return function_reference.post_forward(module, output)

             return new_forward

         forward = self._module_ref.forward

-        fn_ref = FunctionReference()
+        fn_ref = HookFunctionReference()
         fn_ref.pre_forward = hook.pre_forward
         fn_ref.post_forward = hook.post_forward
-        fn_ref.old_forward = forward
+        fn_ref.forward = forward

         if hasattr(hook, "new_forward"):
-            fn_ref.overwritten_forward = forward
-            fn_ref.old_forward = functools.update_wrapper(
+            fn_ref.original_forward = forward
+            fn_ref.forward = functools.update_wrapper(
                 functools.partial(hook.new_forward, self._module_ref), hook.new_forward
             )

@@ -160,25 +165,28 @@ def get_hook(self, name: str) -> Optional[ModelHook]:
         return self.hooks.get(name, None)

     def remove_hook(self, name: str, recurse: bool = True) -> None:
-        num_hooks = len(self._hook_order)
-        if name in self.hooks.keys():
-            hook = self.hooks[name]
-            index = self._hook_order.index(name)
-            fn_ref = self._fn_refs[index]
+        if name not in self.hooks.keys():
+            logger.warning(f"hook: {name} was not found in HookRegistry")
+            return

-            old_forward = fn_ref.old_forward
-            if fn_ref.overwritten_forward is not None:
-                old_forward = fn_ref.overwritten_forward
-
-            if index == num_hooks - 1:
-                self._module_ref.forward = old_forward
-            else:
-                self._fn_refs[index + 1].old_forward = old_forward
-
-            self._module_ref = hook.deinitalize_hook(self._module_ref)
-            del self.hooks[name]
-            self._hook_order.pop(index)
-            self._fn_refs.pop(index)
+        num_hooks = len(self._hook_order)
+        hook = self.hooks[name]
+        index = self._hook_order.index(name)
+        fn_ref = self._fn_refs[index]
+
+        old_forward = fn_ref.forward
+        if fn_ref.original_forward is not None:
+            old_forward = fn_ref.original_forward
+
+        if index == num_hooks - 1:
+            self._module_ref.forward = old_forward
+        else:
+            self._fn_refs[index + 1].forward = old_forward
+
+        self._module_ref = hook.deinitalize_hook(self._module_ref)
+        del self.hooks[name]
+        self._hook_order.pop(index)
+        self._fn_refs.pop(index)

         if recurse:
             for module_name, module in self._module_ref.named_modules():
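
The mechanics being renamed above are easier to see in isolation: each registered hook captures the module's current forward in its HookFunctionReference, so the references form a chain, and removing a hook splices one link out, either by restoring the module's forward (outermost hook) or by re-pointing the next reference (interior hook). Below is a minimal, self-contained sketch of that scheme. ToyModule, SimpleRegistry, and the simplified pre/post signatures are illustrative assumptions for this sketch, not the diffusers API.

    # Minimal sketch of the forward-chaining scheme, simplified from
    # HookRegistry: no module recursion, and new_forward here does not
    # receive the module as its first argument as it does in diffusers.
    class HookFunctionReference:
        def __init__(self):
            self.pre_forward = None
            self.post_forward = None
            self.forward = None           # next link in the chain (was `old_forward`)
            self.original_forward = None  # set only if a hook replaces forward outright


    class ToyModule:
        def forward(self, x):
            return 2 * x


    class SimpleRegistry:
        def __init__(self, module):
            self._module = module
            self._fn_refs = []
            self._order = []

        def register(self, name, pre, post):
            fn_ref = HookFunctionReference()
            fn_ref.pre_forward = pre
            fn_ref.post_forward = post
            fn_ref.forward = self._module.forward  # capture the current forward

            def new_forward(*args, **kwargs):
                args, kwargs = fn_ref.pre_forward(*args, **kwargs)
                output = fn_ref.forward(*args, **kwargs)
                return fn_ref.post_forward(output)

            self._module.forward = new_forward  # this hook is now the outermost link
            self._fn_refs.append(fn_ref)
            self._order.append(name)

        def remove(self, name):
            index = self._order.index(name)
            fn_ref = self._fn_refs[index]
            old_forward = fn_ref.forward
            if fn_ref.original_forward is not None:
                old_forward = fn_ref.original_forward
            if index == len(self._order) - 1:
                # Outermost hook: hand its captured forward back to the module.
                self._module.forward = old_forward
            else:
                # Interior hook: re-point the next wrapper's link past it.
                self._fn_refs[index + 1].forward = old_forward
            self._order.pop(index)
            self._fn_refs.pop(index)


    module = ToyModule()
    registry = SimpleRegistry(module)
    registry.register("add_one", lambda *a, **k: (a, k), lambda out: out + 1)
    print(module.forward(3))  # (2 * 3) + 1 == 7
    registry.remove("add_one")
    print(module.forward(3))  # back to 2 * 3 == 6

The rename also makes the splice read naturally: the next reference's `forward` slot is assigned whatever forward the removed hook was holding.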

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
         )

         if should_compute_attention:
-            output = self.fn_ref.overwritten_forward(*args, **kwargs)
+            output = self.fn_ref.original_forward(*args, **kwargs)
         else:
             output = self.state.cache
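
This second file is the in-tree consumer of the renamed attribute: a hook that supplies new_forward reaches the module's pre-hook forward through fn_ref.original_forward (formerly overwritten_forward). A hedged sketch of that pattern follows; CacheEveryOtherCall is an invented stand-in for the real PyramidAttentionBroadcastHook, not its implementation.

    # Illustrative only: a toy stateful hook that recomputes on odd calls and
    # serves a cached result on even ones, falling through to the module's
    # pre-hook forward via fn_ref.original_forward.
    class CacheEveryOtherCall:
        _is_stateful = True

        def __init__(self):
            self.fn_ref = None  # assigned by the registry at registration time
            self._calls = 0
            self._cache = None

        def new_forward(self, module, *args, **kwargs):
            self._calls += 1
            should_compute = self._cache is None or self._calls % 2 == 1
            if should_compute:
                # `original_forward` is the forward the module had before this
                # hook replaced it (the attribute this commit renames).
                self._cache = self.fn_ref.original_forward(*args, **kwargs)
            return self._cache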
