# limitations under the License.

from dataclasses import dataclass
-from typing import Callable, Optional, Protocol, Tuple
+from typing import Any, Callable, Optional, Tuple

import torch.nn as nn

from ..models.attention_processor import Attention
-from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module
+from ..models.hooks import ModelHook, add_hook_to_module
from ..utils import logging
from .pipeline_utils import DiffusionPipeline


_ATTENTION_CLASSES = (Attention,)

-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
+_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")

@@ -96,21 +96,15 @@ class PyramidAttentionBroadcastState:

    def __init__(self) -> None:
        self.iteration = 0
+        self.cache = None
+
+    def update_state(self, output: Any) -> None:
+        self.iteration += 1
+        self.cache = output

    def reset_state(self):
        self.iteration = 0
-
-
-class nnModulePAB(Protocol):
-    r"""
-    Type hint for a torch.nn.Module that contains a `_pyramid_attention_broadcast_state` attribute.
-
-    Attributes:
-        _pyramid_attention_broadcast_state (`PyramidAttentionBroadcastState`):
-            The state of Pyramid Attention Broadcast.
-    """
-
-    _pyramid_attention_broadcast_state: PyramidAttentionBroadcastState
+        self.cache = None


def apply_pyramid_attention_broadcast(
@@ -247,14 +241,15 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
        )
        return

-    def skip_callback(module: nnModulePAB) -> bool:
+    def skip_callback(module: nn.Module) -> bool:
        pab_state = module._pyramid_attention_broadcast_state
-        current_timestep = pipeline._current_timestep
-        is_within_timestep_range = timestep_skip_range[0] < current_timestep < timestep_skip_range[1]
+        if pab_state.cache is None:
+            return False
+
+        is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1]

        if is_within_timestep_range:
            should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0
-            pab_state.iteration += 1
            return not should_compute_attention

        # We are still not in the phase of inference where skipping attention is possible without minimal quality
@@ -263,3 +258,24 @@ def skip_callback(module: nnModulePAB) -> bool:

    logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
    apply_pyramid_attention_broadcast_on_module(module, skip_callback)
+
+
+class PyramidAttentionBroadcastHook(ModelHook):
+    def __init__(self, skip_callback: Callable[[nn.Module], bool]) -> None:
+        super().__init__()
+
+        self.skip_callback = skip_callback
+
+    def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
+        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
+
+        if self.skip_callback(module):
+            output = module._pyramid_attention_broadcast_state.cache
+        else:
+            output = module._old_forward(*args, **kwargs)
+
+        return module._diffusers_hook.post_forward(module, output)
+
+    def post_forward(self, module: nn.Module, output: Any) -> Any:
+        module._pyramid_attention_broadcast_state.update_state(output)
+        return output
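
For readers unfamiliar with the hook mechanism, here is a minimal usage sketch (not part of the diff) of how the hook, state, and skip callback above fit together. ToyAttention, the hand-attached state, and the iteration-only skip callback are illustrative assumptions, as is the behaviour of add_hook_to_module (assumed to store the hook on module._diffusers_hook, keep the original forward as module._old_forward, and route calls through hook.new_forward):

import torch
import torch.nn as nn

class ToyAttention(nn.Module):  # hypothetical stand-in for a real attention block
    def forward(self, hidden_states):
        return hidden_states * 2

module = ToyAttention()
# In the pipeline this state is attached by apply_pyramid_attention_broadcast_on_module;
# here it is set by hand for illustration.
module._pyramid_attention_broadcast_state = PyramidAttentionBroadcastState()

def skip_callback(m: nn.Module) -> bool:
    # Simplified stand-in for the timestep-aware callback above: never skip until a
    # cached output exists, then recompute only on every second call.
    state = m._pyramid_attention_broadcast_state
    if state.cache is None:
        return False
    return state.iteration % 2 != 0

add_hook_to_module(module, PyramidAttentionBroadcastHook(skip_callback))

x = torch.ones(1, 4)
out1 = module(x)  # attention computed, output cached via post_forward
out2 = module(x)  # cached output returned because skip_callback asks to skip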