 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
@@ -513,7 +513,7 @@ def prepare_latents(
         shape = (batch_size, num_channels_latents, height, width)

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
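+            # latent image ids index the packed 2x2 latent patches, so they are built from half the latent height/width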
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids

         if isinstance(generator, list) and len(generator) != batch_size:
@@ -529,6 +529,41 @@ def prepare_latents(

         return latents, latent_image_ids

+    # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+    def prepare_image(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        if isinstance(image, torch.Tensor):
+            pass
+        else:
+            image = self.image_processor.preprocess(image, height=height, width=width)
+
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -556,9 +591,11 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 3.5,
+        control_image: PipelineImageInput = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
+        control_latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -595,6 +632,14 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
+                `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The control input condition that provides structural guidance to the `transformer` during generation.
+                If the type is specified as `torch.Tensor`, it is passed to the model as is. `PIL.Image.Image` can
+                also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+                height and/or width are passed, `image` is resized accordingly.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -667,6 +712,7 @@ def __call__(

         device = self._execution_device

+        # 3. Prepare text embeddings
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
@@ -686,7 +732,35 @@ def __call__(
         )

         # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
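+        # With a control image, the transformer's in_channels covers both the noise latents and the
+        # concatenated control latents, so each stream keeps half of the packed channels (// 8 instead of // 4).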
+        num_channels_latents = (
+            self.transformer.config.in_channels // 4
+            if control_image is None
+            else self.transformer.config.in_channels // 8
+        )
+
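+        # Encode the control image with the VAE and pack it exactly like the noise latents, so it can be
+        # concatenated with them at every denoising step.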
+        if control_image is not None and control_latents is None:
+            control_image = self.prepare_image(
+                image=control_image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.vae.dtype,
+            )
+
+            control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+            control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
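+            # _pack_latents folds each 2x2 latent patch into the feature dimension, matching the packed
+            # layout of the noise latents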
+            height_control_image, width_control_image = control_latents.shape[2:]
+            control_latents = self._pack_latents(
+                control_latents,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height_control_image,
+                width_control_image,
+            )
+
         latents, latent_image_ids = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -732,11 +806,16 @@ def __call__(
                 if self.interrupt:
                     continue

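+                # Concatenating the packed control latents along the feature dimension (dim=2) conditions the
+                # transformer on the control image without a separate ControlNet branch.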
+                if control_latents is not None:
+                    latent_model_input = torch.cat([latents, control_latents], dim=2)
+                else:
+                    latent_model_input = latents
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
@@ -774,7 +853,6 @@ def __call__(

         if output_type == "latent":
             image = latents
-
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
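
Below is a minimal usage sketch of the new `control_image` argument. It assumes the pipeline class in this file is exposed as `FluxPipeline` and is loaded from a checkpoint whose transformer has `in_channels` doubled to accept the concatenated control latents; the checkpoint path and the conditioning image are placeholders.

```python
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

# Placeholder checkpoint: it must ship a Flux transformer whose in_channels is twice that of the
# base text-to-image model, since the packed control latents are concatenated with the noise latents.
pipe = FluxPipeline.from_pretrained("path/to/flux-control-checkpoint", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Any structural conditioning image works here, e.g. a depth or edge map.
control_image = load_image("path/to/condition.png")

image = pipe(
    prompt="a photo of a cat wearing a spacesuit",
    control_image=control_image,  # VAE-encoded, packed, and concatenated with the latents internally
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux_control_output.png")
```

Because the conditioning is injected by concatenating packed latents along the feature dimension rather than through a separate ControlNet, no extra network is loaded; the only structural requirement is the doubled `in_channels` of the transformer, which is what the `// 8` branch of `num_channels_latents` accounts for.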