 import torch
 from transformers import AutoTokenizer, BertModel, T5EncoderModel
 
-from diffusers import (AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel,
-                       HunyuanDiTPipeline)
-from diffusers.utils.testing_utils import (enable_full_determinism,
-                                           numpy_cosine_similarity_distance,
-                                           require_torch_accelerator, slow,
-                                           torch_device)
-
-from ..pipeline_params import (TEXT_TO_IMAGE_BATCH_PARAMS,
-                               TEXT_TO_IMAGE_IMAGE_PARAMS,
-                               TEXT_TO_IMAGE_PARAMS)
+from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    numpy_cosine_similarity_distance,
+    require_torch_accelerator,
+    slow,
+    torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import (
-    PipelineTesterMixin, check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist, to_np)
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+    to_np,
+)
+
 
 enable_full_determinism()
 
@@ -170,9 +174,9 @@ def test_fused_qkv_projections(self):
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(
-            pipe.transformer
-        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_processors_exist(pipe.transformer), (
+            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        )
         assert check_qkv_fusion_matches_attn_procs_length(
             pipe.transformer, pipe.transformer.original_attn_processors
         ), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -188,15 +192,15 @@ def test_fused_qkv_projections(self):
         image_disabled = pipe(**inputs)[0]
         image_slice_disabled = image_disabled[0, -3:, -3:, -1]
 
-        assert np.allclose(
-            original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
-        ), "Fusion of QKV projections shouldn't affect the outputs."
-        assert np.allclose(
-            image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
-        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
-        assert np.allclose(
-            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
-        ), "Original outputs should match when fused QKV projections are disabled."
+        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+            "Fusion of QKV projections shouldn't affect the outputs."
+        )
+        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        )
+        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+            "Original outputs should match when fused QKV projections are disabled."
+        )
 
     @unittest.skip(
         "Test not supported as `encode_prompt` is called two times separately which deivates from about 99% of the pipelines we have."