@@ -106,6 +106,14 @@ def __init__(self) -> None:
     def reset(self):
         self.iteration = 0
         self.cache = None
+
+    def __repr__(self):
+        cache_repr = ""
+        if self.cache is None:
+            cache_repr = "None"
+        else:
+            cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
+        return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"


 class PyramidAttentionBroadcastHook(ModelHook):
@@ -120,21 +128,21 @@ def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None:

     def initialize_hook(self, module):
         self.state = PyramidAttentionBroadcastState()
+        return module

     def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
-        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
-
         if self.skip_callback(module):
-            output = module._pyramid_attention_broadcast_state.cache
+            output = self.state.cache
         else:
             output = module._old_forward(*args, **kwargs)

         self.state.cache = output
         self.state.iteration += 1
-        return module._diffusers_hook.post_forward(module, output)
+        return output

     def reset_state(self, module: torch.nn.Module) -> None:
-        module.state.reset()
+        self.state.reset()
+        return module


 def apply_pyramid_attention_broadcast(
@@ -168,7 +176,7 @@ def apply_pyramid_attention_broadcast(
     >>> config = PyramidAttentionBroadcastConfig(
     ...     spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800)
     ... )
-    >>> apply_pyramid_attention_broadcast(pipe, config)
+    >>> apply_pyramid_attention_broadcast(pipe.transformer, config)
     ```
     """
     if config.current_timestep_callback is None:
@@ -192,9 +200,9 @@ def apply_pyramid_attention_broadcast(
         if not isinstance(submodule, _ATTENTION_CLASSES):
             continue
         if isinstance(submodule, Attention):
-            _apply_pyramid_attention_broadcast_on_attention_class(name, module, config)
+            _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
         if isinstance(submodule, MochiAttention):
-            _apply_pyramid_attention_broadcast_on_mochi_attention_class(name, module, config)
+            _apply_pyramid_attention_broadcast_on_mochi_attention_class(name, submodule, config)


 def _apply_pyramid_attention_broadcast_on_attention_class(
@@ -241,7 +249,9 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
         return False

     def skip_callback(module: torch.nn.Module) -> bool:
-        pab_state = module._pyramid_attention_broadcast_state
+        hook: PyramidAttentionBroadcastHook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+        pab_state: PyramidAttentionBroadcastState = hook.state
+
         if pab_state.cache is None:
             return False

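
As a usage sketch (not part of this commit), the entry point from the docstring hunk above could be wired up as follows. The pipeline class, the checkpoint name, and the `lambda: pipe.current_timestep` callback are illustrative assumptions; only the config fields and the `apply_pyramid_attention_broadcast(pipe.transformer, config)` call are taken from the diff itself.

```python
import torch
from diffusers import CogVideoXPipeline  # hypothetical pipeline choice for illustration
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    # Re-use cached spatial attention outputs for every other block,
    # but only while the denoising timestep falls inside (100, 800).
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    # Assumed wiring: apply_pyramid_attention_broadcast checks
    # config.current_timestep_callback (see the hunk above), so pass a
    # callable that reports the pipeline's current denoising timestep.
    current_timestep_callback=lambda: pipe.current_timestep,
)

# Per the docstring fix above, the hook is applied to the denoiser module
# (pipe.transformer), not to the pipeline object itself.
apply_pyramid_attention_broadcast(pipe.transformer, config)
```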