
Commit d9fad00: reorder

1 parent c2e0e3b

File tree: 1 file changed (+25, -25 lines)

src/diffusers/pipelines/pyramid_attention_broadcast_utils.py

Lines changed: 25 additions & 25 deletions
@@ -117,6 +117,31 @@ def reset(self):
         self.cache = None
 
 
+class PyramidAttentionBroadcastHook(ModelHook):
+    r"""A hook that applies Pyramid Attention Broadcast to a given module."""
+
+    _is_stateful = True
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
+        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
+        state: PyramidAttentionBroadcastState = module._pyramid_attention_broadcast_state
+
+        if state.skip_callback(module):
+            output = module._pyramid_attention_broadcast_state.cache
+        else:
+            output = module._old_forward(*args, **kwargs)
+
+        state.cache = output
+        state.iteration += 1
+        return module._diffusers_hook.post_forward(module, output)
+
+    def reset_state(self, module: nn.Module) -> None:
+        module._pyramid_attention_broadcast_state.reset()
+
+
 def apply_pyramid_attention_broadcast(
     pipeline: DiffusionPipeline,
     config: Optional[PyramidAttentionBroadcastConfig] = None,
@@ -275,28 +300,3 @@ def _apply_pyramid_attention_broadcast_on_mochi_attention_class(
 ) -> bool:
     # The same logic as Attention class works here, so just use that for now
     return _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config)
-
-
-class PyramidAttentionBroadcastHook(ModelHook):
-    r"""A hook that applies Pyramid Attention Broadcast to a given module."""
-
-    _is_stateful = True
-
-    def __init__(self) -> None:
-        super().__init__()
-
-    def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
-        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
-        state: PyramidAttentionBroadcastState = module._pyramid_attention_broadcast_state
-
-        if state.skip_callback(module):
-            output = module._pyramid_attention_broadcast_state.cache
-        else:
-            output = module._old_forward(*args, **kwargs)
-
-        state.cache = output
-        state.iteration += 1
-        return module._diffusers_hook.post_forward(module, output)
-
-    def reset_state(self, module: nn.Module) -> None:
-        module._pyramid_attention_broadcast_state.reset()
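
For context, the sketch below shows roughly how the relocated hook ends up being attached to a pipeline. It is only an illustration: the import path is assumed from the file touched in this commit, the checkpoint name is a placeholder, and the only thing taken from the hunks above is the (pipeline, config) signature of apply_pyramid_attention_broadcast.

import torch
from diffusers import DiffusionPipeline
from diffusers.pipelines.pyramid_attention_broadcast_utils import apply_pyramid_attention_broadcast

# Placeholder checkpoint; any pipeline whose attention classes the helper
# supports (e.g. the Mochi case handled above) would be used the same way.
pipeline = DiffusionPipeline.from_pretrained("some-org/some-video-model", torch_dtype=torch.bfloat16)
pipeline.to("cuda")

# Registers PyramidAttentionBroadcastHook on the pipeline's attention modules.
# config is Optional in the signature shown in the diff, so passing nothing
# falls back to the module's defaults.
apply_pyramid_attention_broadcast(pipeline)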
