@@ -566,6 +566,41 @@ def prepare_latents(
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids

+    # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+    def prepare_image(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        if isinstance(image, torch.Tensor):
+            pass
+        else:
+            image = self.image_processor.preprocess(image, height=height, width=width)
+
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
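
The broadcast rule in `prepare_image` is easy to misread, so here is a minimal standalone sketch of just that logic (the function name and shapes are illustrative, not part of the pipeline):

```python
import torch

# Illustration of the batching rule in `prepare_image`: a single conditioning
# image is repeated up to the full batch size, while an already-batched input
# (one image per prompt) is repeated once per generated image.
def broadcast_condition(image: torch.Tensor, batch_size: int, num_images_per_prompt: int) -> torch.Tensor:
    repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
    return image.repeat_interleave(repeat_by, dim=0)

single = torch.randn(1, 3, 512, 512)
print(broadcast_condition(single, batch_size=4, num_images_per_prompt=2).shape)
# -> torch.Size([4, 3, 512, 512])
```

Note that at the call site further down, the `batch_size` argument already includes `num_images_per_prompt`, so a single control image is broadcast to every generated sample.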
@@ -595,8 +630,10 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 7.0,
+        control_image: PipelineImageInput = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        control_latents: Optional[torch.FloatTensor] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -646,6 +683,14 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
+                    `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The control image that provides structural guidance to the `transformer` during generation. If the type
+                is specified as `torch.Tensor`, it is passed on as is. `PIL.Image.Image` can also be accepted as an
+                image. The dimensions of the output image default to `image`'s dimensions; if height and/or width are
+                passed, `image` is resized accordingly. Internally the image is encoded with the VAE, packed with
+                `_pack_latents`, and concatenated with the denoising latents along the channel dimension, so it is
+                consumed by the `transformer` itself rather than a separate ControlNet.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
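
For orientation, this is roughly how the new arguments are meant to be used. Everything below is a hypothetical sketch: `pipe` stands for an already-loaded instance of the pipeline this diff modifies, and the URL and image are placeholders not taken from the diff:

```python
import torch
from diffusers.utils import load_image

# Hypothetical usage sketch; `pipe` is an instance of the pipeline being
# modified here, loaded however the project documents it.
control = load_image("https://example.com/canny_edges.png")  # placeholder URL

image = pipe(
    prompt="a photo of a cat",
    control_image=control,   # VAE-encoded and packed internally (see below)
    num_inference_steps=28,
    guidance_scale=7.0,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
```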
@@ -723,6 +768,7 @@ def __call__(

         device = self._execution_device

+        # 3. Prepare text embeddings
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
@@ -769,7 +815,34 @@ def __call__(
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

         # 5. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
+        num_channels_latents = (
+            self.transformer.config.in_channels // 4
+            if control_image is None
+            else self.transformer.config.in_channels // 8
+        )
+
+        if control_image is not None and control_latents is None:
+            control_image = self.prepare_image(
+                image=control_image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.vae.dtype,
+            )
+
+            control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+            control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+            height_control_image, width_control_image = control_latents.shape[2:]
+            control_latents = self._pack_latents(
+                control_latents,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height_control_image,
+                width_control_image,
+            )

         latents, latent_image_ids = self.prepare_latents(
             init_image,
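
Because the encode-and-pack branch only runs when `control_latents is None`, a caller can precompute the packed latents once and reuse them across prompts. A hedged sketch that mirrors the hunk above, reusing the hypothetical `pipe` and `control` from the earlier example (note it leans on the private `_pack_latents` helper, so treat it as illustration rather than supported API):

```python
import torch

# Hypothetical precompute, mirroring the branch above: encode the control
# image once with the VAE, shift/scale into latent space, pack, then reuse.
with torch.no_grad():
    img = pipe.prepare_image(
        image=control, width=1024, height=1024,
        batch_size=1, num_images_per_prompt=1,
        device=pipe.device, dtype=pipe.vae.dtype,
    )
    lat = pipe.vae.encode(img).latent_dist.sample()
    lat = (lat - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
    # _pack_latents(latents, batch_size, num_channels, height, width)
    lat = pipe._pack_latents(lat, 1, lat.shape[1], lat.shape[2], lat.shape[3])

reused = pipe(prompt="the same scene at dusk", control_latents=lat).images[0]
```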
@@ -800,10 +873,16 @@ def __call__(
                 if self.interrupt:
                     continue

+                if control_latents is not None:
+                    latent_model_input = torch.cat([latents, control_latents], dim=2)
+                else:
+                    latent_model_input = latents
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
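
Both tensors are packed to `(batch, seq_len, features)`, so the `dim=2` concatenation doubles the per-token feature size without touching the sequence length or the positional ids. A shape check under assumed Flux-style dimensions (16 latent channels, 2x2 packing, a 1024x1024 output; these constants are assumptions, not stated in the diff):

```python
import torch

# A 1024x1024 image -> 128x128 VAE latent grid -> 64x64 = 4096 packed tokens,
# each with 16 * 2 * 2 = 64 features.
latents = torch.randn(2, 4096, 64)
control_latents = torch.randn(2, 4096, 64)

latent_model_input = torch.cat([latents, control_latents], dim=2)
print(latent_model_input.shape)  # torch.Size([2, 4096, 128])

# 128 would match transformer.config.in_channels when a control image is used,
# which is why num_channels_latents switches from in_channels // 4 to // 8.
```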