
Commit 7e4df06

update

1 parent be67dbd

File tree

1 file changed: +81 −2 lines

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 81 additions & 2 deletions
@@ -566,6 +566,41 @@ def prepare_latents(
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids
 
+    # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+    def prepare_image(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        if isinstance(image, torch.Tensor):
+            pass
+        else:
+            image = self.image_processor.preprocess(image, height=height, width=width)
+
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
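
The batching rule in `prepare_image` is easy to misread, so here is a minimal, self-contained sketch of it (not part of the commit; the helper name `expand_batch` and the tensor sizes are illustrative): a single control image is tiled across the whole batch, while a per-prompt batch is repeated once per generated image.

```python
import torch

def expand_batch(image: torch.Tensor, batch_size: int, num_images_per_prompt: int) -> torch.Tensor:
    # Mirrors the repeat logic above: one image serves the whole batch,
    # otherwise the batch is assumed to hold one image per prompt.
    repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
    return image.repeat_interleave(repeat_by, dim=0)

# The pipeline passes batch_size = prompts * num_images_per_prompt (2 * 2 here).
single = torch.randn(1, 3, 512, 512)      # one control image shared by all prompts
per_prompt = torch.randn(2, 3, 512, 512)  # one control image per prompt
print(expand_batch(single, batch_size=4, num_images_per_prompt=2).shape)      # [4, 3, 512, 512]
print(expand_batch(per_prompt, batch_size=4, num_images_per_prompt=2).shape)  # [4, 3, 512, 512]
```
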
@@ -595,8 +630,10 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 7.0,
+        control_image: PipelineImageInput = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        control_latents: Optional[torch.FloatTensor] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -646,6 +683,14 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The ControlNet input condition to provide guidance to the `transformer` for generation. If the type is
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `__init__`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
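
For reference, a hedged usage sketch of the new `control_image` argument (the checkpoint id and image URLs are placeholders; a stock Flux transformer has `in_channels=64`, so the control path additionally assumes a control-conditioned variant with doubled input channels, which this diff does not supply):

```python
import torch
from diffusers import FluxImg2ImgPipeline
from diffusers.utils import load_image

pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

init_image = load_image("https://example.com/init.png")        # placeholder URL
control_image = load_image("https://example.com/control.png")  # placeholder URL

image = pipe(
    prompt="a cat wearing a space suit",
    image=init_image,
    control_image=control_image,  # new argument introduced by this commit
    strength=0.8,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
```
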
@@ -723,6 +768,7 @@ def __call__(
 
         device = self._execution_device
 
+        # 3. Prepare text embeddings
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
@@ -769,7 +815,34 @@ def __call__(
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
         # 5. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
+        num_channels_latents = (
+            self.transformer.config.in_channels // 4
+            if control_image is None
+            else self.transformer.config.in_channels // 8
+        )
+
+        if control_image is not None and control_latents is None:
+            control_image = self.prepare_image(
+                image=control_image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.vae.dtype,
+            )
+
+            control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+            control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+            height_control_image, width_control_image = control_latents.shape[2:]
+            control_latents = self._pack_latents(
+                control_latents,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height_control_image,
+                width_control_image,
+            )
 
         latents, latent_image_ids = self.prepare_latents(
             init_image,
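
To make the shapes in the control path concrete, here is a sketch of the encode-and-pack step (assuming the usual Flux setup: an 8x spatial VAE with 16 latent channels and `_pack_latents` folding 2x2 patches into channels; the 1024x1024 resolution is illustrative):

```python
import torch

B, C, H, W = 1, 16, 128, 128  # VAE latents for a 1024x1024 control image
control_latents = torch.randn(B, C, H, W)

# The same 2x2 patch packing as _pack_latents:
packed = (
    control_latents.view(B, C, H // 2, 2, W // 2, 2)
    .permute(0, 2, 4, 1, 3, 5)
    .reshape(B, (H // 2) * (W // 2), C * 4)
)
print(packed.shape)  # torch.Size([1, 4096, 64])
```
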
@@ -800,10 +873,16 @@ def __call__(
             if self.interrupt:
                 continue
 
+            if control_latents is not None:
+                latent_model_input = torch.cat([latents, control_latents], dim=2)
+            else:
+                latent_model_input = latents
+
             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
             timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
             noise_pred = self.transformer(
-                hidden_states=latents,
+                hidden_states=latent_model_input,
                 timestep=timestep / 1000,
                 guidance=guidance,
                 pooled_projections=pooled_prompt_embeds,
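
Finally, the concatenation in the denoising loop runs along the packed channel axis (`dim=2`), which is why `num_channels_latents` switches from `in_channels // 4` to `in_channels // 8` when a control image is given. A sketch, assuming a transformer trained with `in_channels=128`:

```python
import torch

latents = torch.randn(1, 4096, 64)          # packed image latents
control_latents = torch.randn(1, 4096, 64)  # packed control latents
hidden_states = torch.cat([latents, control_latents], dim=2)
print(hidden_states.shape)  # torch.Size([1, 4096, 128]) == transformer.config.in_channels
# 128 // 8 == 16 latent channels per stream before packing: the // 4 accounts
# for the 2x2 patch packing, the extra // 2 for the control concatenation.
```
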
