Commit 0232b29

Apply style fixes
1 parent 8d55997 commit 0232b29

File tree

1 file changed (+21 additions, -25 deletions)


tests/pipelines/qwenimage/test_qwenimage_controlnet.py

Lines changed: 21 additions & 25 deletions
@@ -21,12 +21,11 @@
 from diffusers import (
     AutoencoderKLQwenImage,
     FlowMatchEulerDiscreteScheduler,
+    QwenImageControlNetModel,
     QwenImageControlNetPipeline,
+    QwenImageMultiControlNetModel,
     QwenImageTransformer2DModel,
-    QwenImageControlNetModel,
-    QwenImageMultiControlNetModel
 )
-
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device
 from diffusers.utils.torch_utils import randn_tensor
 
@@ -219,48 +218,45 @@ def test_qwen_controlnet_multicondition(self):
         device = "cpu"
         components = self.get_dummy_components()
 
-        components["controlnet"] = QwenImageMultiControlNetModel(
-            [components["controlnet"]]
-        )
+        components["controlnet"] = QwenImageMultiControlNetModel([components["controlnet"]])
 
         pipe = self.pipeline_class(**components)
         pipe.to(device)
         pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
-        control_image = inputs["control_image"]
-        inputs["control_image"] = [control_image, control_image]
-        inputs["controlnet_conditioning_scale"] = [0.5, 0.5]
+        control_image = inputs["control_image"]
+        inputs["control_image"] = [control_image, control_image]
+        inputs["controlnet_conditioning_scale"] = [0.5, 0.5]
 
         image = pipe(**inputs).images
         generated_image = image[0]
         self.assertEqual(generated_image.shape, (3, 32, 32))
         # Expected slice from the generated image
         expected_slice = torch.tensor(
             [
-                0.6239,
-                0.6642,
-                0.5768,
-                0.6039,
-                0.5270,
-                0.5070,
-                0.5006,
-                0.5271,
+                0.6239,
+                0.6642,
+                0.5768,
+                0.6039,
+                0.5270,
+                0.5070,
+                0.5006,
+                0.5271,
                 0.4506,
-                0.3085,
-                0.3435,
-                0.5152,
-                0.5096,
-                0.5422,
-                0.4286,
-                0.5752
+                0.3085,
+                0.3435,
+                0.5152,
+                0.5096,
+                0.5422,
+                0.4286,
+                0.5752,
             ]
         )
 
         generated_slice = generated_image.flatten()
         generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
         self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
-
 
     def test_attention_slicing_forward_pass(
         self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3

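For reference, the multi-condition path this test exercises maps onto end-user code roughly as follows. This is a minimal sketch, not part of the commit: the from_pretrained checkpoint ids, the prompt, and control.png are placeholder assumptions; only the QwenImageMultiControlNetModel wrapper and the list-valued control_image / controlnet_conditioning_scale arguments mirror the test above.

import torch

from diffusers import (
    QwenImageControlNetModel,
    QwenImageControlNetPipeline,
    QwenImageMultiControlNetModel,
)
from diffusers.utils import load_image

# Placeholder checkpoint ids (assumptions) -- substitute real QwenImage weights.
controlnet = QwenImageControlNetModel.from_pretrained(
    "your-org/qwenimage-controlnet", torch_dtype=torch.bfloat16
)
# As in the test, a single ControlNet wrapped in QwenImageMultiControlNetModel
# can be applied to several conditions at once.
pipe = QwenImageControlNetPipeline.from_pretrained(
    "Qwen/Qwen-Image",
    controlnet=QwenImageMultiControlNetModel([controlnet]),
    torch_dtype=torch.bfloat16,
)

control_image = load_image("control.png")  # placeholder conditioning image
# One control image and one conditioning scale per condition.
image = pipe(
    prompt="a photo",  # placeholder prompt
    control_image=[control_image, control_image],
    controlnet_conditioning_scale=[0.5, 0.5],
).images[0]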