| 23 | 23 |     FlowMatchEulerDiscreteScheduler, | 
| 24 | 24 |     QwenImageControlNetPipeline, | 
| 25 | 25 |     QwenImageTransformer2DModel, | 
|  | 26 | +    QwenImageControlNetModel, |
|  | 27 | +    QwenImageMultiControlNetModel, |
| 26 | 28 | ) | 
| 27 |  | -from diffusers.models.controlnets.controlnet_qwenimage import QwenImageControlNetModel | 
|  | 29 | + | 
| 28 | 30 | from diffusers.utils.testing_utils import enable_full_determinism, torch_device | 
| 29 | 31 | from diffusers.utils.torch_utils import randn_tensor | 
| 30 | 32 | 
 | 
| @@ -213,6 +215,53 @@ def test_qwen_controlnet(self): | 
| 213 | 215 |         generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) | 
| 214 | 216 |         self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) | 
| 215 | 217 | 
 | 
|  | 218 | +    def test_qwen_controlnet_multicondition(self): | 
|  | 219 | +        device = "cpu" | 
|  | 220 | +        components = self.get_dummy_components() | 
|  | 221 | + | 
|  | 222 | +        components["controlnet"] = QwenImageMultiControlNetModel( | 
|  | 223 | +            [components["controlnet"]] | 
|  | 224 | +        ) | 
|  | 225 | + | 
|  | 226 | +        pipe = self.pipeline_class(**components) | 
|  | 227 | +        pipe.to(device) | 
|  | 228 | +        pipe.set_progress_bar_config(disable=None) | 
|  | 229 | + | 
|  | 230 | +        inputs = self.get_dummy_inputs(device) | 
|  | 231 | +        control_image = inputs["control_image"] |
|  | 232 | +        inputs["control_image"] = [control_image, control_image]  # two control conditions |
|  | 233 | +        inputs["controlnet_conditioning_scale"] = [0.5, 0.5]  # one scale per condition |
|  | 234 | + | 
|  | 235 | +        image = pipe(**inputs).images | 
|  | 236 | +        generated_image = image[0] | 
|  | 237 | +        self.assertEqual(generated_image.shape, (3, 32, 32)) | 
|  | 238 | +        # Expected slice from the generated image | 
|  | 239 | +        expected_slice = torch.tensor( | 
|  | 240 | +            [ | 
|  | 241 | +                0.6239, |
|  | 242 | +                0.6642, |
|  | 243 | +                0.5768, |
|  | 244 | +                0.6039, |
|  | 245 | +                0.5270, |
|  | 246 | +                0.5070, |
|  | 247 | +                0.5006, |
|  | 248 | +                0.5271, |
|  | 249 | +                0.4506, |
|  | 250 | +                0.3085, |
|  | 251 | +                0.3435, |
|  | 252 | +                0.5152, |
|  | 253 | +                0.5096, |
|  | 254 | +                0.5422, |
|  | 255 | +                0.4286, |
|  | 256 | +                0.5752 |
|  | 257 | +            ] | 
|  | 258 | +        ) | 
|  | 259 | + | 
|  | 260 | +        generated_slice = generated_image.flatten() | 
|  | 261 | +        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) | 
|  | 262 | +        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) | 
|  | 263 | + |
|  | 264 | + | 
| 216 | 265 |     def test_attention_slicing_forward_pass( | 
| 217 | 266 |         self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 | 
| 218 | 267 |     ): | 