@@ -566,6 +566,41 @@ def prepare_latents(
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids
 
+    # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+    def prepare_image(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        if isinstance(image, torch.Tensor):
+            pass
+        else:
+            image = self.image_processor.preprocess(image, height=height, width=width)
+
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
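For reference, the batching rule in `prepare_image` tiles a single conditioning image across the whole batch (the call site later in this diff passes `batch_size * num_images_per_prompt` as `batch_size`). A minimal standalone sketch of that behavior, using a dummy tensor in place of a preprocessed image:

```python
import torch

# One preprocessed conditioning image, shape (1, C, H, W).
image = torch.randn(1, 3, 64, 64)

# The pipeline passes batch_size * num_images_per_prompt as `batch_size`,
# so a single conditioning image is tiled across the effective batch.
batch_size, num_images_per_prompt = 4, 2  # effective batch of 4 (2 prompts x 2 images)
repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)

print(image.shape)  # torch.Size([4, 3, 64, 64])
```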
@@ -595,8 +630,10 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 7.0,
+        control_image: PipelineImageInput = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        control_latents: Optional[torch.FloatTensor] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -646,6 +683,14 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
                 `prompt`, usually at the expense of lower image quality.
+            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
+                `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The ControlNet input condition that provides guidance to the `transformer` for generation. If the type
+                is specified as `torch.Tensor`, it is passed on as-is. `PIL.Image.Image` can also be accepted as an
+                image. The dimensions of the output image default to `image`'s dimensions. If height and/or width are
+                passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, images must
+                be passed as a list such that each element of the list can be correctly batched for input to a single
+                ControlNet.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
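To illustrate the new call signature, here is a hedged usage sketch; the pipeline class name, checkpoint id, and image URL are placeholders, since none of them appear in this diff:

```python
import torch
from diffusers.utils import load_image

# Placeholder class and checkpoint names -- the actual ones are not shown in this diff.
pipe = SomeFluxControlPipeline.from_pretrained("org/control-checkpoint", torch_dtype=torch.bfloat16)
pipe.to("cuda")

control_image = load_image("https://example.com/condition.png")  # placeholder URL
image = pipe(
    prompt="a photo of a cat",
    control_image=control_image,  # VAE-encoded, packed, then concatenated to the latents
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
```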
@@ -723,6 +768,7 @@ def __call__(
 
         device = self._execution_device
 
+        # 3. Prepare text embeddings
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
@@ -769,7 +815,34 @@ def __call__(
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
         # 5. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
+        num_channels_latents = (
+            self.transformer.config.in_channels // 4
+            if control_image is None
+            else self.transformer.config.in_channels // 8
+        )
+
+        if control_image is not None and control_latents is None:
+            control_image = self.prepare_image(
+                image=control_image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.vae.dtype,
+            )
+
+            control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+            control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+            height_control_image, width_control_image = control_latents.shape[2:]
+            control_latents = self._pack_latents(
+                control_latents,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height_control_image,
+                width_control_image,
+            )
 
         latents, latent_image_ids = self.prepare_latents(
             init_image,
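Two notes on the arithmetic above. `_pack_latents` folds each 2x2 patch of the latent grid into the channel axis, so the packed channel count is 4x the VAE latent channels; when control latents are concatenated, the transformer's `in_channels` doubles, which is why it is divided by 8 instead of 4 to recover the per-stream channel count. A minimal sketch of the packing transform, assuming the 2x2 patchification used by Flux-style pipelines:

```python
import torch

# _pack_latents folds 2x2 spatial patches into the channel axis, turning a
# (B, C, H, W) latent into a (B, H//2 * W//2, C*4) token sequence.
def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

control_latents = torch.randn(1, 16, 64, 64)  # 16-channel VAE latent
packed = pack_latents(control_latents)
print(packed.shape)  # torch.Size([1, 1024, 64])
```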
@@ -800,10 +873,16 @@ def __call__(
                 if self.interrupt:
                     continue
 
+                if control_latents is not None:
+                    latent_model_input = torch.cat([latents, control_latents], dim=2)
+                else:
+                    latent_model_input = latents
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
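Since both streams are packed to a (batch, sequence, channels) layout, `dim=2` concatenates along the per-token channel axis rather than the sequence; a quick shape check using the shapes from the packing sketch above:

```python
import torch

latents = torch.randn(1, 1024, 64)          # packed noise latents: (B, seq, C)
control_latents = torch.randn(1, 1024, 64)  # packed control latents, same sequence length

# Concatenating on dim=2 doubles the per-token channels (64 -> 128),
# matching the in_channels // 8 bookkeeping earlier in the diff.
latent_model_input = torch.cat([latents, control_latents], dim=2)
print(latent_model_input.shape)  # torch.Size([1, 1024, 128])
```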