
Commit 66f835b

add multi-cn

1 parent 3316b98 commit 66f835b

File tree

2 files changed: +73, -11 lines


src/diffusers/models/controlnets/controlnet_qwenimage.py

Lines changed: 2 additions & 2 deletions

@@ -295,7 +295,7 @@ def forward(
         )
 
 
-class QwenImageMultiControlNetModel(ModelMixin):
+class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
     `QwenImageMultiControlNetModel` wrapper class for Multi-QwenImageControlNetModel
 
@@ -356,4 +356,4 @@ def forward(
         else:
             raise ValueError("QwenImageMultiControlNetModel only supports controlnet-union now.")
 
-        return control_block_samples
+        return control_block_samples
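For orientation, below is a minimal sketch of the dispatch pattern such a multi-ControlNet wrapper typically follows: it keeps its sub-controlnets in an `nn.ModuleList` and accumulates their per-block residuals. The diff only states that the class "only supports controlnet-union now", so reusing a single union controlnet for every condition is an assumption here; this is an illustration, not the file's actual implementation.

```py
# Illustrative sketch only (assumed behaviour, not the code in controlnet_qwenimage.py):
# a wrapper that dispatches a list of conditions and sums the per-block residuals.
import torch.nn as nn


class MultiControlNetSketch(nn.Module):  # hypothetical stand-in for the wrapper class
    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)  # registered so .to()/.cuda() propagate

    def forward(self, hidden_states, controlnet_cond, conditioning_scale, **kwargs):
        # controlnet_cond / conditioning_scale are lists, one entry per condition
        control_block_samples = None
        for i, (cond, scale) in enumerate(zip(controlnet_cond, conditioning_scale)):
            # with a single union controlnet, reuse it for every condition (assumed)
            net = self.nets[i] if len(self.nets) > 1 else self.nets[0]
            samples = net(
                hidden_states=hidden_states,
                controlnet_cond=cond,
                conditioning_scale=scale,
                **kwargs,
            )
            # accumulate the residuals returned for each transformer block
            if control_block_samples is None:
                control_block_samples = list(samples)
            else:
                control_block_samples = [p + c for p, c in zip(control_block_samples, samples)]
        return control_block_samples
```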

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py

Lines changed: 71 additions & 9 deletions
@@ -45,8 +45,9 @@
         ```py
         >>> import torch
         >>> from diffusers.utils import load_image
-        >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetPipeline
+        >>> from diffusers import QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageControlNetPipeline
 
+        >>> # QwenImageControlNetModel
         >>> controlnet = QwenImageControlNetModel.from_pretrained("InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16)
         >>> pipe = QwenImageControlNetPipeline.from_pretrained("Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
@@ -57,6 +58,19 @@
         >>> # Refer to the pipeline documentation for more details.
         >>> image = pipe(prompt, negative_prompt=negative_prompt, control_image=control_image, controlnet_conditioning_scale=1.0, num_inference_steps=30, true_cfg_scale=4.0).images[0]
         >>> image.save("qwenimage_cn_union.png")
+
+        >>> # QwenImageMultiControlNetModel
+        >>> controlnet = QwenImageControlNetModel.from_pretrained("InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16)
+        >>> controlnet = QwenImageMultiControlNetModel([controlnet])
+        >>> pipe = QwenImageControlNetPipeline.from_pretrained("Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16)
+        >>> pipe.to("cuda")
+        >>> prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
+        >>> negative_prompt = " "
+        >>> control_image = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png")
+        >>> # Depending on the variant being used, the pipeline call will slightly vary.
+        >>> # Refer to the pipeline documentation for more details.
+        >>> image = pipe(prompt, negative_prompt=negative_prompt, control_image=[control_image, control_image], controlnet_conditioning_scale=[0.5, 0.5], num_inference_steps=30, true_cfg_scale=4.0).images[0]
+        >>> image.save("qwenimage_cn_union_multi.png")
         ```
         """
 
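Building on the docstring example above, the natural use of the list interface is to pass two different conditions with their own strengths. Below is a hedged sketch that reuses the objects from the example; the depth-map file name is a hypothetical placeholder standing in for any second condition image.

```py
# Sketch only: two different conditions through the same union controlnet wrapper.
# "your_depth_condition.png" is a hypothetical placeholder, not a file shipped with the model.
>>> canny_image = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png")
>>> depth_image = load_image("your_depth_condition.png")  # hypothetical second condition
>>> image = pipe(
...     prompt,
...     negative_prompt=negative_prompt,
...     control_image=[canny_image, depth_image],
...     controlnet_conditioning_scale=[0.7, 0.3],  # per-condition strengths
...     num_inference_steps=30,
...     true_cfg_scale=4.0,
... ).images[0]
>>> image.save("qwenimage_cn_union_two_conditions.png")
```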
@@ -177,7 +191,9 @@ def __init__(
         text_encoder: Qwen2_5_VLForConditionalGeneration,
         tokenizer: Qwen2Tokenizer,
         transformer: QwenImageTransformer2DModel,
-        controlnet: QwenImageControlNetModel,
+        controlnet: Union[
+            QwenImageControlNetModel, QwenImageMultiControlNetModel
+        ],
     ):
         super().__init__()
 
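Beyond the widened type hint, nothing else in `__init__` has to change: a `DiffusionPipeline` registers whichever variant it received and the rest of the pipeline only ever touches `self.controlnet`. A brief assumed sketch of that standard registration call (not part of this diff; the other component names are taken from the rest of the file, not shown here):

```py
# Assumed continuation of __init__, following the usual DiffusionPipeline pattern.
self.register_modules(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    transformer=transformer,
    controlnet=controlnet,  # single QwenImageControlNetModel or the multi wrapper
    scheduler=scheduler,
)
```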
@@ -589,7 +605,7 @@ def __call__(
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
         elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
-            mult = len(self.controlnet.nets) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
+            mult = len(control_image) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
             control_guidance_start, control_guidance_end = (
                 mult * [control_guidance_start],
                 mult * [control_guidance_end],
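The switch from `len(self.controlnet.nets)` to `len(control_image)` matters because a single union controlnet can now serve several conditions, so the guidance windows must be broadcast per condition rather than per sub-network. For reference, a small self-contained sketch of how such start/end windows are usually turned into per-step keep flags in diffusers controlnet pipelines (an assumed pattern, not code from this file):

```py
# Assumed sketch of the usual diffusers gating pattern; the exact code in this
# pipeline is not part of the diff and may differ.
def make_controlnet_keep(num_steps, control_guidance_start, control_guidance_end):
    """Per-step gate (1.0 = active, 0.0 = skipped) for each controlnet condition."""
    controlnet_keep = []
    for i in range(num_steps):
        keeps = [
            1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps)
    return controlnet_keep


# Two conditions, the second only active for the first half of the schedule:
# make_controlnet_keep(4, [0.0, 0.0], [1.0, 0.5])
# -> [[1.0, 1.0], [1.0, 1.0], [1.0, 0.0], [1.0, 0.0]]
```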
@@ -657,11 +673,11 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=self.vae.dtype,
-            )  # torch.Size([1, 3, height_ori, width_ori])
+            )
             height, width = control_image.shape[-2:]
 
             if control_image.ndim == 4:
-                control_image = control_image.unsqueeze(2)  # torch.Size([1, 3, 1, height_ori, width_ori])
+                control_image = control_image.unsqueeze(2)
 
             # vae encode
             self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
@@ -675,7 +691,7 @@ def __call__(
             control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
             control_image = (control_image - latents_mean) * latents_std
 
-            control_image = control_image.permute(0, 2, 1, 3, 4)  # torch.Size([1, 1, 16, height_ori//8, width_ori//8])
+            control_image = control_image.permute(0, 2, 1, 3, 4)
 
             # pack
             control_image = self._pack_latents(
@@ -684,7 +700,53 @@ def __call__(
                 num_channels_latents=num_channels_latents,
                 height=control_image.shape[3],
                 width=control_image.shape[4],
-            )
+            ).to(dtype=prompt_embeds.dtype, device=device)
+
+        else:
+            if isinstance(self.controlnet, QwenImageMultiControlNetModel):
+                control_images = []
+                for control_image_ in control_image:
+                    control_image_ = self.prepare_image(
+                        image=control_image_,
+                        width=width,
+                        height=height,
+                        batch_size=batch_size * num_images_per_prompt,
+                        num_images_per_prompt=num_images_per_prompt,
+                        device=device,
+                        dtype=self.vae.dtype,
+                    )
+
+                    height, width = control_image_.shape[-2:]
+
+                    if control_image_.ndim == 4:
+                        control_image_ = control_image_.unsqueeze(2)
+
+                    # vae encode
+                    self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+                    latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(
+                        device
+                    )
+                    latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                        device
+                    )
+
+                    control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
+                    control_image_ = (control_image_ - latents_mean) * latents_std
+
+                    control_image_ = control_image_.permute(0, 2, 1, 3, 4)
+
+                    # pack
+                    control_image_ = self._pack_latents(
+                        control_image_,
+                        batch_size=control_image_.shape[0],
+                        num_channels_latents=num_channels_latents,
+                        height=control_image_.shape[3],
+                        width=control_image_.shape[4],
+                    ).to(dtype=prompt_embeds.dtype, device=device)
+
+                    control_images.append(control_image_)
+
+                control_image = control_images
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
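Each control latent, whether from the single-controlnet branch or from this per-condition loop, ends up packed into the transformer's token layout by `_pack_latents`. As a rough guide, here is a self-contained sketch of the 2x2 patch packing this family of pipelines uses; the pipeline's own helper may differ in details:

```py
# Illustrative sketch of 2x2 patch packing (assumed shape convention, not the pipeline's
# exact helper): (B, C, H, W) latents become (B, (H/2)*(W/2), C*4) "tokens".
import torch


def pack_latents_sketch(latents: torch.Tensor) -> torch.Tensor:
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)  # group each 2x2 spatial patch with its channels
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)  # one token per 2x2 patch


# e.g. a 16-channel 64x64 latent becomes 32*32 = 1024 tokens of width 64:
# pack_latents_sketch(torch.randn(1, 16, 64, 64)).shape -> torch.Size([1, 1024, 64])
```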
@@ -756,11 +818,11 @@ def __call__(
                 if isinstance(controlnet_cond_scale, list):
                     controlnet_cond_scale = controlnet_cond_scale[0]
                 cond_scale = controlnet_cond_scale * controlnet_keep[i]
-
+
                 # controlnet
                 controlnet_block_samples = self.controlnet(
                     hidden_states=latents,
-                    controlnet_cond=control_image.to(dtype=latents.dtype, device=device),
+                    controlnet_cond=control_image,
                     conditioning_scale=cond_scale,
                     timestep=timestep / 1000,
                     encoder_hidden_states=prompt_embeds,
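For context, the block samples returned here are subsequently handed to the transformer as per-block residuals for the denoising step. A hedged sketch of that hand-off follows; the keyword name for the residuals is an assumption based on similar diffusers pipelines and is not shown in this diff.

```py
# Assumed sketch of the subsequent transformer call inside the denoising loop
# (not part of this diff); the controlnet_block_samples kwarg name is an assumption.
noise_pred = self.transformer(
    hidden_states=latents,
    timestep=timestep / 1000,
    encoder_hidden_states=prompt_embeds,
    controlnet_block_samples=controlnet_block_samples,  # per-block residuals added inside
    return_dict=False,
)[0]
```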
