@@ -45,8 +45,9 @@
         ```py
         >>> import torch
         >>> from diffusers.utils import load_image
-        >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetPipeline
+        >>> from diffusers import QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageControlNetPipeline

+        >>> # QwenImageControlNetModel
         >>> controlnet = QwenImageControlNetModel.from_pretrained("InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16)
         >>> pipe = QwenImageControlNetPipeline.from_pretrained("Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
@@ -57,6 +58,19 @@
         >>> # Refer to the pipeline documentation for more details.
         >>> image = pipe(prompt, negative_prompt=negative_prompt, control_image=control_image, controlnet_conditioning_scale=1.0, num_inference_steps=30, true_cfg_scale=4.0).images[0]
         >>> image.save("qwenimage_cn_union.png")
+
+        >>> # QwenImageMultiControlNetModel
+        >>> controlnet = QwenImageControlNetModel.from_pretrained("InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16)
+        >>> controlnet = QwenImageMultiControlNetModel([controlnet])
+        >>> pipe = QwenImageControlNetPipeline.from_pretrained("Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16)
+        >>> pipe.to("cuda")
+        >>> prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
+        >>> negative_prompt = " "
+        >>> control_image = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png")
+        >>> # Depending on the variant being used, the pipeline call will slightly vary.
+        >>> # Refer to the pipeline documentation for more details.
+        >>> image = pipe(prompt, negative_prompt=negative_prompt, control_image=[control_image, control_image], controlnet_conditioning_scale=[0.5, 0.5], num_inference_steps=30, true_cfg_scale=4.0).images[0]
+        >>> image.save("qwenimage_cn_union_multi.png")
         ```
 """

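A usage note on the new example: it reuses the same canny condition twice at scale `0.5` each, which (assuming the per-net block residuals are combined additively and scale linearly) should roughly match the single-net call at scale `1.0`; in practice each list entry would be a different condition with its own scale.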
@@ -177,7 +191,9 @@ def __init__(
         text_encoder: Qwen2_5_VLForConditionalGeneration,
         tokenizer: Qwen2Tokenizer,
         transformer: QwenImageTransformer2DModel,
-        controlnet: QwenImageControlNetModel,
+        controlnet: Union[
+            QwenImageControlNetModel, QwenImageMultiControlNetModel
+        ],
     ):
         super().__init__()

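The widened `Union` type mirrors what callers can now pass. A minimal sketch of normalizing either form at a call site (the helper name is hypothetical, and it assumes `QwenImageMultiControlNetModel` wraps a list of `QwenImageControlNetModel` instances, as the docstring example suggests):

```py
# Hypothetical convenience helper, not part of the PR: coerce whatever the
# caller passed into the multi-model form so downstream code sees one shape.
def as_multi_controlnet(controlnet):
    if isinstance(controlnet, (list, tuple)):
        return QwenImageMultiControlNetModel(list(controlnet))
    if isinstance(controlnet, QwenImageControlNetModel):
        return QwenImageMultiControlNetModel([controlnet])
    return controlnet  # assumed to already be a QwenImageMultiControlNetModel
```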
@@ -589,7 +605,7 @@ def __call__(
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
         elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
-            mult = len(self.controlnet.nets) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
+            mult = len(control_image) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
             control_guidance_start, control_guidance_end = (
                 mult * [control_guidance_start],
                 mult * [control_guidance_end],
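This hunk derives the number of guidance windows from `len(control_image)`, the number of conditions actually passed, rather than `len(self.controlnet.nets)`. The scalar-to-list broadcast itself is easy to check in isolation; a small standalone sketch with illustrative names:

```py
def broadcast_guidance(start, end, num_conds):
    # Mirror scalar start/end across all conditions so each ControlNet
    # gets its own (start, end) fraction of the denoising schedule.
    if not isinstance(start, list) and isinstance(end, list):
        start = len(end) * [start]
    elif not isinstance(end, list) and isinstance(start, list):
        end = len(start) * [end]
    elif not isinstance(start, list) and not isinstance(end, list):
        start, end = num_conds * [start], num_conds * [end]
    return start, end

assert broadcast_guidance(0.0, 1.0, 2) == ([0.0, 0.0], [1.0, 1.0])
```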
@@ -657,11 +673,11 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=self.vae.dtype,
-            )  # torch.Size([1, 3, height_ori, width_ori])
+            )
             height, width = control_image.shape[-2:]

             if control_image.ndim == 4:
-                control_image = control_image.unsqueeze(2)  # torch.Size([1, 3, 1, height_ori, width_ori])
+                control_image = control_image.unsqueeze(2)

             # vae encode
             self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
@@ -675,7 +691,7 @@ def __call__(
             control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
             control_image = (control_image - latents_mean) * latents_std

-            control_image = control_image.permute(0, 2, 1, 3, 4)  # torch.Size([1, 1, 16, height_ori//8, width_ori//8])
+            control_image = control_image.permute(0, 2, 1, 3, 4)

             # pack
             control_image = self._pack_latents(
@@ -684,7 +700,53 @@ def __call__(
                 num_channels_latents=num_channels_latents,
                 height=control_image.shape[3],
                 width=control_image.shape[4],
-            )
+            ).to(dtype=prompt_embeds.dtype, device=device)
+
+        else:
+            if isinstance(self.controlnet, QwenImageMultiControlNetModel):
+                control_images = []
+                for control_image_ in control_image:
+                    control_image_ = self.prepare_image(
+                        image=control_image_,
+                        width=width,
+                        height=height,
+                        batch_size=batch_size * num_images_per_prompt,
+                        num_images_per_prompt=num_images_per_prompt,
+                        device=device,
+                        dtype=self.vae.dtype,
+                    )
+
+                    height, width = control_image_.shape[-2:]
+
+                    if control_image_.ndim == 4:
+                        control_image_ = control_image_.unsqueeze(2)
+
+                    # vae encode
+                    self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+                    latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(
+                        device
+                    )
+                    latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                        device
+                    )
+
+                    control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
+                    control_image_ = (control_image_ - latents_mean) * latents_std
+
+                    control_image_ = control_image_.permute(0, 2, 1, 3, 4)
+
+                    # pack
+                    control_image_ = self._pack_latents(
+                        control_image_,
+                        batch_size=control_image_.shape[0],
+                        num_channels_latents=num_channels_latents,
+                        height=control_image_.shape[3],
+                        width=control_image_.shape[4],
+                    ).to(dtype=prompt_embeds.dtype, device=device)
+
+                    control_images.append(control_image_)
+
+                control_image = control_images

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
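Each condition in the list goes through the same encode-and-pack path as the single-image branch, and the `.to(dtype=prompt_embeds.dtype, device=device)` cast now happens here rather than at the ControlNet call. As a rough shape sanity check, a standalone sketch of the 2x2 patchify that `_pack_latents` performs (sizes are illustrative, assuming a 16-channel latent space and 8x spatial downsampling):

```py
import torch

# A 1024x1024 condition image -> (B, C, H, W) = (1, 16, 128, 128) latent.
b, c, h, w = 1, 16, 128, 128
latent = torch.randn(b, c, h, w)

# Group each 2x2 spatial patch into the channel dim, then flatten to a
# token sequence: (B, C, H, W) -> (B, H//2 * W//2, C * 4).
packed = latent.view(b, c, h // 2, 2, w // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // 2) * (w // 2), c * 4)
print(packed.shape)  # torch.Size([1, 4096, 64])
```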
@@ -756,11 +818,11 @@ def __call__(
                     if isinstance(controlnet_cond_scale, list):
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]
-
+
                 # controlnet
                 controlnet_block_samples = self.controlnet(
                     hidden_states=latents,
-                    controlnet_cond=control_image.to(dtype=latents.dtype, device=device),
+                    controlnet_cond=control_image,
                     conditioning_scale=cond_scale,
                     timestep=timestep / 1000,
                     encoder_hidden_states=prompt_embeds,
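Since `control_image` may now be a list of packed latents, the raw value is forwarded as `controlnet_cond`. A plausible reading of how the multi-model consumes it, assuming `QwenImageMultiControlNetModel` follows the pattern of diffusers' existing `MultiControlNetModel` (a sketch, not the actual implementation):

```py
# Hypothetical forward: run each net on its own condition/scale and sum the
# per-block residuals, so N ControlNets cost N forward passes per step.
def multi_forward(nets, hidden_states, conds, scales, **kwargs):
    block_samples = None
    for net, cond, scale in zip(nets, conds, scales):
        samples = net(
            hidden_states=hidden_states,
            controlnet_cond=cond,
            conditioning_scale=scale,
            **kwargs,
        )
        block_samples = (
            samples
            if block_samples is None
            else [prev + cur for prev, cur in zip(block_samples, samples)]
        )
    return block_samples
```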