@@ -117,6 +117,31 @@ def reset(self):
117117 self .cache = None
118118
119119
120+ class PyramidAttentionBroadcastHook (ModelHook ):
121+ r"""A hook that applies Pyramid Attention Broadcast to a given module."""
122+
123+ _is_stateful = True
124+
125+ def __init__ (self ) -> None :
126+ super ().__init__ ()
127+
128+ def new_forward (self , module : nn .Module , * args , ** kwargs ) -> Any :
129+ args , kwargs = module ._diffusers_hook .pre_forward (module , * args , ** kwargs )
130+ state : PyramidAttentionBroadcastState = module ._pyramid_attention_broadcast_state
131+
132+ if state .skip_callback (module ):
133+ output = module ._pyramid_attention_broadcast_state .cache
134+ else :
135+ output = module ._old_forward (* args , ** kwargs )
136+
137+ state .cache = output
138+ state .iteration += 1
139+ return module ._diffusers_hook .post_forward (module , output )
140+
141+ def reset_state (self , module : nn .Module ) -> None :
142+ module ._pyramid_attention_broadcast_state .reset ()
143+
144+
120145def apply_pyramid_attention_broadcast (
121146 pipeline : DiffusionPipeline ,
122147 config : Optional [PyramidAttentionBroadcastConfig ] = None ,
@@ -275,28 +300,3 @@ def _apply_pyramid_attention_broadcast_on_mochi_attention_class(
275300) -> bool :
276301 # The same logic as Attention class works here, so just use that for now
277302 return _apply_pyramid_attention_broadcast_on_attention_class (pipeline , name , module , config )
278-
279-
280- class PyramidAttentionBroadcastHook (ModelHook ):
281- r"""A hook that applies Pyramid Attention Broadcast to a given module."""
282-
283- _is_stateful = True
284-
285- def __init__ (self ) -> None :
286- super ().__init__ ()
287-
288- def new_forward (self , module : nn .Module , * args , ** kwargs ) -> Any :
289- args , kwargs = module ._diffusers_hook .pre_forward (module , * args , ** kwargs )
290- state : PyramidAttentionBroadcastState = module ._pyramid_attention_broadcast_state
291-
292- if state .skip_callback (module ):
293- output = module ._pyramid_attention_broadcast_state .cache
294- else :
295- output = module ._old_forward (* args , ** kwargs )
296-
297- state .cache = output
298- state .iteration += 1
299- return module ._diffusers_hook .post_forward (module , output )
300-
301- def reset_state (self , module : nn .Module ) -> None :
302- module ._pyramid_attention_broadcast_state .reset ()
0 commit comments