@@ -221,16 +221,19 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
         and not module.is_cross_attention
     )
 
-    block_skip_range, timestep_skip_range = None, None
+    block_skip_range, timestep_skip_range, block_type = None, None, None
     if is_spatial_self_attention:
         block_skip_range = config.spatial_attention_block_skip_range
         timestep_skip_range = config.spatial_attention_timestep_skip_range
+        block_type = "spatial"
     elif is_temporal_self_attention:
         block_skip_range = config.temporal_attention_block_skip_range
         timestep_skip_range = config.temporal_attention_timestep_skip_range
+        block_type = "temporal"
     elif is_cross_attention:
         block_skip_range = config.cross_attention_block_skip_range
         timestep_skip_range = config.cross_attention_timestep_skip_range
+        block_type = "cross"
 
     if block_skip_range is None or timestep_skip_range is None:
         logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}")
@@ -250,5 +253,5 @@ def skip_callback(module: nn.Module) -> bool:
         # loss, as described in the paper. So, the attention computation cannot be skipped
         return False
 
-    logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}")
+    logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
     apply_pyramid_attention_broadcast_on_module(module, skip_callback)
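For context, here is a minimal self-contained sketch of the branch this hunk changes: the three attention types (spatial, temporal, cross) each pull their own skip ranges from the config, and the new `block_type` string records which branch was taken so the debug log can name it. The `PABConfigSketch` dataclass and `resolve_block_settings` helper below are hypothetical stand-ins for illustration only; their field names mirror the attributes read in the diff, but this is not the actual diffusers API.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical stand-in for the config object referenced in the diff; field names
# mirror the attributes read above, but this is not the diffusers config class.
@dataclass
class PABConfigSketch:
    spatial_attention_block_skip_range: Optional[int] = None
    spatial_attention_timestep_skip_range: Optional[Tuple[int, int]] = None
    temporal_attention_block_skip_range: Optional[int] = None
    temporal_attention_timestep_skip_range: Optional[Tuple[int, int]] = None
    cross_attention_block_skip_range: Optional[int] = None
    cross_attention_timestep_skip_range: Optional[Tuple[int, int]] = None


def resolve_block_settings(
    config: PABConfigSketch,
    is_spatial_self_attention: bool,
    is_temporal_self_attention: bool,
    is_cross_attention: bool,
):
    # Same branching as the hunk: pick the skip ranges for this attention type
    # and record which branch was taken so log messages can report it.
    block_skip_range, timestep_skip_range, block_type = None, None, None
    if is_spatial_self_attention:
        block_skip_range = config.spatial_attention_block_skip_range
        timestep_skip_range = config.spatial_attention_timestep_skip_range
        block_type = "spatial"
    elif is_temporal_self_attention:
        block_skip_range = config.temporal_attention_block_skip_range
        timestep_skip_range = config.temporal_attention_timestep_skip_range
        block_type = "temporal"
    elif is_cross_attention:
        block_skip_range = config.cross_attention_block_skip_range
        timestep_skip_range = config.cross_attention_timestep_skip_range
        block_type = "cross"
    return block_skip_range, timestep_skip_range, block_type


if __name__ == "__main__":
    config = PABConfigSketch(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(100, 800),
    )
    # A spatial self-attention layer resolves to its configured ranges and "spatial".
    print(resolve_block_settings(config, True, False, False))
    # -> (2, (100, 800), 'spatial')
```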