     UNet2DConditionModel,
     apply_pyramid_attention_broadcast,
 )
+from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
 from diffusers.models.attention_processor import AttnProcessor
 from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from diffusers.models.unets.unet_motion_model import UNetMotionModel
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
-from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastHook
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_xformers_available
@@ -2298,7 +2298,9 @@ def test_pyramid_attention_broadcast_layers(self):
         pipe = self.pipeline_class(**components)
         pipe.set_progress_bar_config(disable=None)
 
-        apply_pyramid_attention_broadcast(pipe, self.pab_config)
+        self.pab_config.current_timestep_callback = lambda: pipe._current_timestep
+        denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+        apply_pyramid_attention_broadcast(denoiser, self.pab_config)
 
         expected_hooks = 0
         if self.pab_config.spatial_attention_block_skip_range is not None:
@@ -2312,30 +2314,30 @@ def test_pyramid_attention_broadcast_layers(self):
         count = 0
         for module in denoiser.modules():
             if hasattr(module, "_diffusers_hook"):
+                hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+                if hook is None:
+                    continue
                 count += 1
                 self.assertTrue(
-                    isinstance(module._diffusers_hook, PyramidAttentionBroadcastHook),
+                    isinstance(hook, PyramidAttentionBroadcastHook),
                     "Hook should be of type PyramidAttentionBroadcastHook.",
                 )
-                self.assertTrue(
-                    hasattr(module, "_pyramid_attention_broadcast_state"),
-                    "PAB state should be initialized when enabled.",
-                )
-                self.assertTrue(
-                    module._pyramid_attention_broadcast_state.cache is None, "Cache should be None at initialization."
-                )
+                self.assertTrue(hook.state.cache is None, "Cache should be None at initialization.")
         self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.")
 
         # Perform dummy inference step to ensure state is updated
         def pab_state_check_callback(pipe, i, t, kwargs):
             for module in denoiser.modules():
                 if hasattr(module, "_diffusers_hook"):
+                    hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+                    if hook is None:
+                        continue
                     self.assertTrue(
-                        module._pyramid_attention_broadcast_state.cache is not None,
+                        hook.state.cache is not None,
                         "Cache should have updated during inference.",
                     )
                     self.assertTrue(
-                        module._pyramid_attention_broadcast_state.iteration == i + 1,
+                        hook.state.iteration == i + 1,
                         "Hook iteration state should have updated during inference.",
                     )
             return {}
@@ -2348,12 +2350,15 @@ def pab_state_check_callback(pipe, i, t, kwargs):
         # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
         for module in denoiser.modules():
             if hasattr(module, "_diffusers_hook"):
+                hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+                if hook is None:
+                    continue
                 self.assertTrue(
-                    module._pyramid_attention_broadcast_state.cache is None,
+                    hook.state.cache is None,
                     "Cache should be reset to None after inference.",
                 )
                 self.assertTrue(
-                    module._pyramid_attention_broadcast_state.iteration == 0,
+                    hook.state.iteration == 0,
                     "Iteration should be reset to 0 after inference.",
                 )
 
@@ -2374,7 +2379,9 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2)
         original_image_slice = output.flatten()
         original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:]))
 
-        apply_pyramid_attention_broadcast(pipe, self.pab_config)
+        self.pab_config.current_timestep_callback = lambda: pipe._current_timestep
+        denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+        apply_pyramid_attention_broadcast(denoiser, self.pab_config)
 
         inputs = self.get_dummy_inputs(device)
         inputs["num_inference_steps"] = 4
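
Taken together, these hunks move `apply_pyramid_attention_broadcast` from targeting the pipeline to targeting the denoiser module (`pipe.transformer` or `pipe.unet`), with the current timestep supplied through a config callback. A minimal end-to-end sketch of that call pattern, assuming the `PyramidAttentionBroadcastConfig` exported from `diffusers.hooks` alongside `apply_pyramid_attention_broadcast`; the checkpoint and skip-range values are illustrative, not taken from this diff:

import torch
from diffusers import CogVideoXPipeline
# Assumed import path, mirroring the hooks module referenced in this diff.
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

# Illustrative pipeline/checkpoint choice; any pipeline exposing a
# transformer or unet denoiser follows the same pattern.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,               # illustrative value
    spatial_attention_timestep_skip_range=(100, 800),   # illustrative value
    current_timestep_callback=lambda: pipe._current_timestep,
)

# As in the updated tests: apply the hook to the denoiser module, not the pipeline.
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
apply_pyramid_attention_broadcast(denoiser, config)

video = pipe("A cat waving at the camera", num_inference_steps=50).frames[0]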