Commit 1040c91

more fixes
1 parent fb66167 commit 1040c91

File tree

3 files changed: +32 −15 lines changed


src/diffusers/hooks/hooks.py

Lines changed: 13 additions & 4 deletions
@@ -121,16 +121,20 @@ def register_hook(self, hook: ModelHook, name: str) -> None:
         self._module_ref = hook.initialize_hook(self._module_ref)
 
         if hasattr(hook, "new_forward"):
-            new_forward = hook.new_forward
+            rewritten_forward = hook.new_forward
+
+            def new_forward(module, *args, **kwargs):
+                args, kwargs = hook.pre_forward(module, *args, **kwargs)
+                output = rewritten_forward(module, *args, **kwargs)
+                return hook.post_forward(module, output)
         else:
 
             def new_forward(module, *args, **kwargs):
                 args, kwargs = hook.pre_forward(module, *args, **kwargs)
                 output = old_forward(*args, **kwargs)
                 return hook.post_forward(module, output)
 
-        new_forward = functools.update_wrapper(new_forward, old_forward)
-        self._module_ref.forward = new_forward.__get__(self._module_ref)
+        self._module_ref.forward = functools.update_wrapper(functools.partial(new_forward, self._module_ref), old_forward)
 
         self.hooks[name] = hook
         self._hook_order.append(name)
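The hunk above changes two things in register_hook: a hook that defines its own new_forward now has that forward bracketed by the hook's pre_forward/post_forward, and the resulting wrapper is attached to the module with functools.partial plus functools.update_wrapper instead of being bound via __get__. The snippet below is a minimal standalone sketch of that wrapping pattern, not the diffusers implementation; ToyHook and the _old_forward attribute are illustrative stand-ins.

import functools

import torch


class ToyHook:
    # Illustrative stand-in for a hook that ships its own new_forward.
    def pre_forward(self, module, *args, **kwargs):
        print("pre_forward")
        return args, kwargs

    def post_forward(self, module, output):
        print("post_forward")
        return output

    def new_forward(self, module, *args, **kwargs):
        print("custom forward")
        return module._old_forward(*args, **kwargs)


module = torch.nn.Linear(4, 4)
hook = ToyHook()
old_forward = module.forward
module._old_forward = old_forward

# Same shape as the hunk above: capture the hook's forward, then bracket it
# with pre_forward/post_forward inside a plain module-first function.
rewritten_forward = hook.new_forward


def new_forward(module, *args, **kwargs):
    args, kwargs = hook.pre_forward(module, *args, **kwargs)
    output = rewritten_forward(module, *args, **kwargs)
    return hook.post_forward(module, output)


# functools.partial pre-binds the module (replacing new_forward.__get__), and
# update_wrapper copies the original forward's metadata onto the wrapper.
module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)

out = module(torch.randn(2, 4))  # prints pre_forward, custom forward, post_forward
print(out.shape)  # torch.Size([2, 4])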
@@ -147,11 +151,16 @@ def remove_hook(self, name: str) -> None:
         del self.hooks[name]
         self._hook_order.remove(name)
 
-    def reset_stateful_hooks(self):
+    def reset_stateful_hooks(self, recurse: bool = True) -> None:
         for hook_name in self._hook_order:
             hook = self.hooks[hook_name]
             if hook._is_stateful:
                 hook.reset_state(self._module_ref)
+
+        if recurse:
+            for module in self._module_ref.modules():
+                if hasattr(module, "_diffusers_hook"):
+                    module._diffusers_hook.reset_stateful_hooks(recurse=False)
 
     @classmethod
     def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
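reset_stateful_hooks now also walks self._module_ref.modules() and resets any registry attached to a submodule, passing recurse=False so the traversal only fans out from the top-level call. A hedged toy sketch of that traversal follows; ToyRegistry, CountingHook, and the _registry attribute are hypothetical names, not the diffusers HookRegistry API.

import torch


class CountingHook:
    # Stand-in for a stateful hook; counts how often reset_state is called.
    _is_stateful = True

    def __init__(self):
        self.resets = 0

    def reset_state(self, module):
        self.resets += 1


class ToyRegistry:
    # Hypothetical per-module registry, mirroring only the reset traversal.
    def __init__(self, module):
        self._module_ref = module
        self.stateful_hooks = []

    def reset_stateful_hooks(self, recurse: bool = True) -> None:
        for hook in self.stateful_hooks:
            hook.reset_state(self._module_ref)

        # As in the hunk above: only the top-level call recurses, and nested
        # registries are invoked with recurse=False so each is reset once.
        if recurse:
            for submodule in self._module_ref.modules():
                if hasattr(submodule, "_registry") and submodule._registry is not self:
                    submodule._registry.reset_stateful_hooks(recurse=False)


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
inner_hook = CountingHook()
model[0]._registry = ToyRegistry(model[0])
model[0]._registry.stateful_hooks.append(inner_hook)

top_registry = ToyRegistry(model)
top_registry.reset_stateful_hooks()  # walks model.modules() and resets the nested registry
print(inner_hook.resets)  # 1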

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 19 additions & 9 deletions
@@ -106,6 +106,14 @@ def __init__(self) -> None:
     def reset(self):
         self.iteration = 0
         self.cache = None
+
+    def __repr__(self):
+        cache_repr = ""
+        if self.cache is None:
+            cache_repr = "None"
+        else:
+            cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
+        return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"
 
 
 class PyramidAttentionBroadcastHook(ModelHook):
@@ -120,21 +128,21 @@ def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None:
 
     def initialize_hook(self, module):
         self.state = PyramidAttentionBroadcastState()
+        return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
-        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
-
         if self.skip_callback(module):
-            output = module._pyramid_attention_broadcast_state.cache
+            output = self.state.cache
         else:
             output = module._old_forward(*args, **kwargs)
 
         self.state.cache = output
         self.state.iteration += 1
-        return module._diffusers_hook.post_forward(module, output)
+        return output
 
     def reset_state(self, module: torch.nn.Module) -> None:
-        module.state.reset()
+        self.state.reset()
+        return module
 
 
 def apply_pyramid_attention_broadcast(
@@ -168,7 +176,7 @@ def apply_pyramid_attention_broadcast(
         >>> config = PyramidAttentionBroadcastConfig(
         ...     spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800)
         ... )
-        >>> apply_pyramid_attention_broadcast(pipe, config)
+        >>> apply_pyramid_attention_broadcast(pipe.transformer, config)
         ```
     """
     if config.current_timestep_callback is None:
@@ -192,9 +200,9 @@ def apply_pyramid_attention_broadcast(
         if not isinstance(submodule, _ATTENTION_CLASSES):
            continue
        if isinstance(submodule, Attention):
-            _apply_pyramid_attention_broadcast_on_attention_class(name, module, config)
+            _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
        if isinstance(submodule, MochiAttention):
-            _apply_pyramid_attention_broadcast_on_mochi_attention_class(name, module, config)
+            _apply_pyramid_attention_broadcast_on_mochi_attention_class(name, submodule, config)
 
 
 def _apply_pyramid_attention_broadcast_on_attention_class(
@@ -241,7 +249,9 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
         return False
 
     def skip_callback(module: torch.nn.Module) -> bool:
-        pab_state = module._pyramid_attention_broadcast_state
+        hook: PyramidAttentionBroadcastHook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+        pab_state: PyramidAttentionBroadcastState = hook.state
+
         if pab_state.cache is None:
             return False
 
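With these changes the per-layer Pyramid Attention Broadcast state lives on the hook itself (hook.state, reachable through module._diffusers_hook.get_hook(...)) instead of an ad-hoc module attribute, and a skipped step simply returns the cached attention output. The toy below sketches that cache-and-broadcast idea under assumed names; ToyState, ToyBroadcastWrapper, and the every-Nth-call reuse rule are illustrative, not the exact diffusers skip logic.

import torch


class ToyState:
    # Mirrors the role of PyramidAttentionBroadcastState: an iteration counter
    # plus the most recent attention output.
    def __init__(self):
        self.iteration = 0
        self.cache = None


class ToyBroadcastWrapper:
    # Illustrative wrapper, not the diffusers hook: recompute attention only on
    # every `block_skip_range`-th call, otherwise broadcast the cached output.
    def __init__(self, attention_fn, block_skip_range):
        self.attention_fn = attention_fn
        self.block_skip_range = block_skip_range
        self.state = ToyState()

    def __call__(self, hidden_states):
        should_compute = self.state.cache is None or self.state.iteration % self.block_skip_range == 0
        output = self.attention_fn(hidden_states) if should_compute else self.state.cache
        self.state.cache = output
        self.state.iteration += 1
        return output


attn = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
wrapped = ToyBroadcastWrapper(lambda h: attn(h, h, h)[0], block_skip_range=2)

x = torch.randn(1, 4, 8)
out0 = wrapped(x)  # iteration 0: attention is computed
out1 = wrapped(x)  # iteration 1: cached output is broadcast instead
print(out1 is out0)  # True: the cached tensor is reused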

src/diffusers/models/attention_processor.py

Lines changed: 0 additions & 2 deletions
@@ -930,8 +930,6 @@ def __init__(
         self.out_dim = out_dim if out_dim is not None else query_dim
         self.out_context_dim = out_context_dim if out_context_dim else query_dim
         self.context_pre_only = context_pre_only
-        # TODO(aryan): Maybe try to improve the checks in PAB instead
-        self.is_cross_attention = False
 
         self.heads = out_dim // dim_head if out_dim is not None else heads
 
0 commit comments
