Commit 8d55997

added test qwenimage multicontrolnet

1 parent f7e210b

File tree: 1 file changed (+50, -1)


tests/pipelines/qwenimage/test_qwenimage_controlnet.py

Lines changed: 50 additions & 1 deletion
@@ -23,8 +23,10 @@
     FlowMatchEulerDiscreteScheduler,
     QwenImageControlNetPipeline,
     QwenImageTransformer2DModel,
+    QwenImageControlNetModel,
+    QwenImageMultiControlNetModel
 )
-from diffusers.models.controlnets.controlnet_qwenimage import QwenImageControlNetModel
+
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device
 from diffusers.utils.torch_utils import randn_tensor
 
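This hunk folds the ControlNet imports into the existing parenthesized import at the top of the test file, dropping the separate import from the private diffusers.models.controlnets.controlnet_qwenimage module. A minimal sketch of the resulting import block, assuming both classes are re-exported at the diffusers package root as the updated test relies on:

# Sketch of the consolidated import the hunk produces (assumes top-level re-exports).
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    QwenImageControlNetPipeline,
    QwenImageTransformer2DModel,
    QwenImageControlNetModel,
    QwenImageMultiControlNetModel
)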

@@ -213,6 +215,53 @@ def test_qwen_controlnet(self):
         generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
         self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
+    def test_qwen_controlnet_multicondition(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+
+        components["controlnet"] = QwenImageMultiControlNetModel(
+            [components["controlnet"]]
+        )
+
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        control_image = inputs["control_image"]
+        inputs["control_image"] = [control_image, control_image]
+        inputs["controlnet_conditioning_scale"] = [0.5, 0.5]
+
+        image = pipe(**inputs).images
+        generated_image = image[0]
+        self.assertEqual(generated_image.shape, (3, 32, 32))
+        # Expected slice from the generated image
+        expected_slice = torch.tensor(
+            [
+                0.6239,
+                0.6642,
+                0.5768,
+                0.6039,
+                0.5270,
+                0.5070,
+                0.5006,
+                0.5271,
+                0.4506,
+                0.3085,
+                0.3435,
+                0.5152,
+                0.5096,
+                0.5422,
+                0.4286,
+                0.5752
+            ]
+        )
+
+        generated_slice = generated_image.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+
     def test_attention_slicing_forward_pass(
         self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
     ):
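The new test exercises the multi-ControlNet path end to end: it wraps a single QwenImageControlNetModel in QwenImageMultiControlNetModel, then passes a list of control images and a matching list of per-condition scales, reusing the same model twice. A sketch of how the same pattern could look at inference time, following the test's call signature; the checkpoint IDs and image path below are illustrative placeholders, not taken from this commit:

import torch
from diffusers import (
    QwenImageControlNetModel,
    QwenImageControlNetPipeline,
    QwenImageMultiControlNetModel,
)
from diffusers.utils import load_image

# Placeholder checkpoint ID -- not part of this commit.
controlnet = QwenImageControlNetModel.from_pretrained("org/qwenimage-controlnet")

# Wrap one or more ControlNets in a list; the test applies the same model
# twice with separate images and scales.
multi_controlnet = QwenImageMultiControlNetModel([controlnet])

pipe = QwenImageControlNetPipeline.from_pretrained(
    "Qwen/Qwen-Image",  # placeholder base checkpoint
    controlnet=multi_controlnet,
    torch_dtype=torch.bfloat16,
)

control_image = load_image("control.png")  # placeholder path
image = pipe(
    prompt="a photo of a cat",
    control_image=[control_image, control_image],  # one entry per condition
    controlnet_conditioning_scale=[0.5, 0.5],      # must match control_image length
).images[0]

As in the test, the length of controlnet_conditioning_scale is expected to match the number of control images, with each scale weighting its corresponding condition.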

0 commit comments