2121from  diffusers  import  (
2222    AutoencoderKLQwenImage ,
2323    FlowMatchEulerDiscreteScheduler ,
24+     QwenImageControlNetPipeline ,
2425    QwenImageTransformer2DModel ,
25-     QwenImageControlNetPipeline 
2626)
2727from  diffusers .models .controlnets .controlnet_qwenimage  import  QwenImageControlNetModel 
28- from  diffusers .utils  import  load_image 
2928from  diffusers .utils .testing_utils  import  enable_full_determinism , torch_device 
3029from  diffusers .utils .torch_utils  import  randn_tensor 
3130
32- from  ..pipeline_params  import  TEXT_TO_IMAGE_BATCH_PARAMS ,  TEXT_TO_IMAGE_IMAGE_PARAMS ,  TEXT_TO_IMAGE_PARAMS 
33- from  ..test_pipelines_common  import  PipelineTesterMixin , to_np ,  FluxIPAdapterTesterMixin 
31+ from  ..pipeline_params  import  TEXT_TO_IMAGE_PARAMS 
32+ from  ..test_pipelines_common  import  PipelineTesterMixin , to_np 
3433
3534
3635enable_full_determinism ()
3736
3837
39- 
4038class  QwenControlNetPipelineFastTests (PipelineTesterMixin , unittest .TestCase ):
4139    pipeline_class  =  QwenImageControlNetPipeline 
42-     params  =  (TEXT_TO_IMAGE_PARAMS  |  frozenset (["control_image" , "controlnet_conditioning_scale" ])) -  {"cross_attention_kwargs" }
40+     params  =  (TEXT_TO_IMAGE_PARAMS  |  frozenset (["control_image" , "controlnet_conditioning_scale" ])) -  {
41+         "cross_attention_kwargs" 
42+     }
4343    batch_params  =  frozenset (["prompt" , "negative_prompt" , "control_image" ])
4444    image_params  =  frozenset (["control_image" ])
4545    image_latents_params  =  frozenset (["latents" ])
@@ -56,7 +56,7 @@ class QwenControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
5656            "callback_on_step_end_tensor_inputs" ,
5757        ]
5858    )
59-   
59+ 
6060    supports_dduf  =  False 
6161    test_xformers_attention  =  True 
6262    test_layerwise_casting  =  True 
@@ -75,31 +75,29 @@ def get_dummy_components(self):
7575            guidance_embeds = False ,
7676            axes_dims_rope = (8 , 4 , 4 ),
7777        )
78-     
7978
8079        torch .manual_seed (0 )
8180        controlnet  =  QwenImageControlNetModel (
8281            patch_size = 2 ,
83-             in_channels = 16 ,         
84-             out_channels = 4 ,          
85-             num_layers = 2 ,             
86-             attention_head_dim = 16 ,    
87-             num_attention_heads = 3 ,   
88-             joint_attention_dim = 16 ,  
89-             axes_dims_rope = (8 , 4 , 4 )  
82+             in_channels = 16 ,
83+             out_channels = 4 ,
84+             num_layers = 2 ,
85+             attention_head_dim = 16 ,
86+             num_attention_heads = 3 ,
87+             joint_attention_dim = 16 ,
88+             axes_dims_rope = (8 , 4 , 4 ), 
9089        )
9190
92- 
9391        torch .manual_seed (0 )
94-         z_dim  =  4    
92+         z_dim  =  4 
9593        vae  =  AutoencoderKLQwenImage (
96-             base_dim = z_dim  *  6 ,           
97-             z_dim = z_dim ,                  
98-             dim_mult = [1 , 2 , 4 ],           
99-             num_res_blocks = 1 ,             
94+             base_dim = z_dim  *  6 ,
95+             z_dim = z_dim ,
96+             dim_mult = [1 , 2 , 4 ],
97+             num_res_blocks = 1 ,
10098            temperal_downsample = [False , True ],
101-             latents_mean = [0.0 ] *  z_dim ,   
102-             latents_std = [1.0 ] *  z_dim ,    
99+             latents_mean = [0.0 ] *  z_dim ,
100+             latents_std = [1.0 ] *  z_dim ,
103101        )
104102
105103        torch .manual_seed (0 )
@@ -191,13 +189,30 @@ def test_qwen_controlnet(self):
191189
192190        # Expected slice from the generated image 
193191        expected_slice  =  torch .tensor (
194-             [0.4726 , 0.5549 , 0.6324 , 0.6548 , 0.4968 , 0.4639 , 0.4749 , 0.4898 , 0.4725 , 0.4645 , 0.4435 , 0.3339 , 0.3400 , 0.4630 , 0.3879 , 0.4406 ]
192+             [
193+                 0.4726 ,
194+                 0.5549 ,
195+                 0.6324 ,
196+                 0.6548 ,
197+                 0.4968 ,
198+                 0.4639 ,
199+                 0.4749 ,
200+                 0.4898 ,
201+                 0.4725 ,
202+                 0.4645 ,
203+                 0.4435 ,
204+                 0.3339 ,
205+                 0.3400 ,
206+                 0.4630 ,
207+                 0.3879 ,
208+                 0.4406 ,
209+             ]
195210        )
196211
197212        generated_slice  =  generated_image .flatten ()
198213        generated_slice  =  torch .cat ([generated_slice [:8 ], generated_slice [- 8 :]])
199214        self .assertTrue (torch .allclose (generated_slice , expected_slice , atol = 1e-3 ))
200-          
215+ 
201216    def  test_attention_slicing_forward_pass (
202217        self , test_max_difference = True , test_mean_pixel_difference = True , expected_max_diff = 1e-3 
203218    ):
@@ -277,5 +292,3 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
277292            expected_diff_max ,
278293            "VAE tiling should not affect the inference results" ,
279294        )
280-     
281-     
0 commit comments