@@ -137,20 +137,34 @@ class PyramidAttentionBroadcastHook(ModelHook):
 
     _is_stateful = True
 
-    def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None:
+    def __init__(
+        self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
+    ) -> None:
         super().__init__()
 
-        self.skip_callback = skip_callback
+        self.timestep_skip_range = timestep_skip_range
+        self.block_skip_range = block_skip_range
+        self.current_timestep_callback = current_timestep_callback
 
     def initialize_hook(self, module):
         self.state = PyramidAttentionBroadcastState()
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
-        if self.skip_callback(module):
-            output = self.state.cache
-        else:
+        is_within_timestep_range = (
+            self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
+        )
+        should_compute_attention = (
+            self.state.cache is None
+            or self.state.iteration == 0
+            or not is_within_timestep_range
+            or self.state.iteration % self.block_skip_range == 0
+        )
+
+        if should_compute_attention:
             output = module._old_forward(*args, **kwargs)
+        else:
+            output = self.state.cache
 
         self.state.cache = output
         self.state.iteration += 1
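The refactored `new_forward` folds the old per-module `skip_callback` into the hook itself: attention is recomputed whenever the cache is empty, on the first iteration, when the current timestep falls outside `timestep_skip_range`, or whenever the iteration counter hits a multiple of `block_skip_range`; otherwise the cached output is broadcast. A minimal standalone sketch of that cadence (plain Python, no diffusers imports; the helper name and the example values are illustrative only, not taken from this diff):

```python
from typing import Tuple


def should_compute_attention(
    iteration: int,
    current_timestep: int,
    cache_is_empty: bool,
    timestep_skip_range: Tuple[int, int] = (100, 800),  # example window, not a default from this diff
    block_skip_range: int = 2,
) -> bool:
    # Mirrors the decision in PyramidAttentionBroadcastHook.new_forward: recompute when there is
    # nothing cached, on the very first call, outside the skip window, or every
    # `block_skip_range`-th iteration; otherwise reuse (broadcast) the cached attention states.
    is_within_timestep_range = timestep_skip_range[0] < current_timestep < timestep_skip_range[1]
    return (
        cache_is_empty
        or iteration == 0
        or not is_within_timestep_range
        or iteration % block_skip_range == 0
    )


# With block_skip_range=2 and a timestep inside the window, attention is recomputed on even
# iterations and the cached states are reused on odd ones (i.e. skipped N - 1 times for N = 2).
for i in range(6):
    print(i, should_compute_attention(iteration=i, current_timestep=500, cache_is_empty=(i == 0)))
# 0 True, 1 False, 2 True, 3 False, 4 True, 5 False
```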
@@ -266,44 +280,35 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
         )
         return False
 
-    def skip_callback(module: torch.nn.Module) -> bool:
-        hook: PyramidAttentionBroadcastHook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
-        pab_state: PyramidAttentionBroadcastState = hook.state
-
-        if pab_state.cache is None:
-            return False
-
-        is_within_timestep_range = timestep_skip_range[0] < config.current_timestep_callback() < timestep_skip_range[1]
-        if not is_within_timestep_range:
-            # We are still not in the phase of inference where skipping attention is possible without minimal quality
-            # loss, as described in the paper. So, the attention computation cannot be skipped
-            return False
-
-        should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0
-        return not should_compute_attention
-
     logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
-    _apply_pyramid_attention_broadcast(module, skip_callback)
+    _apply_pyramid_attention_broadcast_hook(
+        module, timestep_skip_range, block_skip_range, config.current_timestep_callback
+    )
     return True
 
 
-def _apply_pyramid_attention_broadcast(
+def _apply_pyramid_attention_broadcast_hook(
     module: Union[Attention, MochiAttention],
-    skip_callback: Callable[[torch.nn.Module], bool],
+    timestep_skip_range: Tuple[int, int],
+    block_skip_range: int,
+    current_timestep_callback: Callable[[], int],
 ):
     r"""
     Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
 
     Args:
         module (`torch.nn.Module`):
             The module to apply Pyramid Attention Broadcast to.
-        skip_callback (`Callable[[nn.Module], bool]`):
-            A callback function that determines whether the attention computation should be skipped or not. The
-            callback function should return a boolean value, where `True` indicates that the attention computation
-            should be skipped, and `False` indicates that the attention computation should not be skipped. The callback
-            function will receive a torch.nn.Module containing a `_pyramid_attention_broadcast_state` attribute that
-            can should be used to retrieve and update the state of PAB for the given module.
+        timestep_skip_range (`Tuple[int, int]`):
+            The range of timesteps to skip in the attention layer. The attention computations will be conditionally
+            skipped if the current timestep is within the specified range.
+        block_skip_range (`int`):
+            The number of times a specific attention broadcast is skipped before computing the attention states to
+            re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
+            attention states will be re-used) before computing the new attention states again.
+        current_timestep_callback (`Callable[[], int]`):
+            A callback function that returns the current inference timestep.
     """
     registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = PyramidAttentionBroadcastHook(skip_callback)
+    hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
     registry.register_hook(hook, "pyramid_attention_broadcast")