Commit 058fe2f

Added IP adapter support to Flux img2img Pipeline
1 parent 5d2d239 commit 058fe2f

File tree: 2 files changed (+173 additions, -8 deletions)
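
Example usage (illustrative, not part of the committed files): the sketch below shows how the new arguments are expected to fit together once an IP-Adapter is loaded. The checkpoint and image-encoder repo IDs, the file names, and the use of load_ip_adapter / set_ip_adapter_scale from FluxIPAdapterMixin are assumptions for illustration, not something this commit defines.

import torch
from diffusers import FluxImg2ImgPipeline
from diffusers.utils import load_image

# Illustrative repo IDs; substitute the Flux base model and Flux IP-Adapter
# checkpoint you actually use.
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# FluxIPAdapterMixin (now a base class of this pipeline) is assumed to expose
# load_ip_adapter / set_ip_adapter_scale as on the other Flux pipelines.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(0.6)

init_image = load_image("input.png")       # image to be transformed (img2img input)
style_image = load_image("reference.png")  # image whose appearance the IP-Adapter injects

result = pipe(
    prompt="a photo in the style of the reference image",
    negative_prompt="blurry, low quality",  # with true_cfg_scale > 1, enables the new true-CFG branch
    true_cfg_scale=3.5,                     # runs a second (negative) transformer pass per step
    image=init_image,
    strength=0.8,
    ip_adapter_image=style_image,
    num_inference_steps=28,
).images[0]
result.save("output.png")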

2 files changed

+173
-8
lines changed

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 169 additions & 6 deletions
@@ -17,10 +17,17 @@

 import numpy as np
 import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+    T5EncoderModel,
+    T5TokenizerFast,
+)

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -159,7 +166,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin):
     r"""
     The Flux pipeline for image inpainting.

@@ -186,8 +193,8 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
             [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
     """

-    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]

     def __init__(
@@ -199,6 +206,8 @@ def __init__(
         text_encoder_2: T5EncoderModel,
         tokenizer_2: T5TokenizerFast,
         transformer: FluxTransformer2DModel,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
     ):
         super().__init__()

@@ -210,6 +219,8 @@ def __init__(
             tokenizer_2=tokenizer_2,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
@@ -394,6 +405,47 @@ def encode_prompt(
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

         return prompt_embeds, pooled_prompt_embeds, text_ids
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        return image_embeds
+
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
+            ):
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+
+                image_embeds.append(single_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
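
A further sketch of the two helpers added above (continuing from the usage example near the top of this page): the image embeddings can be precomputed once and passed back through ip_adapter_image_embeds on later calls, so encode_image is not re-run. This assumes an IP-Adapter is already loaded, since prepare_ip_adapter_image_embeds reads transformer.encoder_hid_proj.image_projection_layers; variable names are illustrative.

# Precompute the CLIP image embeddings once with the new helper...
image_embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=style_image,   # a single image, or a list if several IP-Adapters are loaded
    ip_adapter_image_embeds=None,   # None -> compute embeddings from ip_adapter_image
    device=pipe.device,
    num_images_per_prompt=1,
)

# ...then reuse them on subsequent calls instead of passing ip_adapter_image again.
result = pipe(
    prompt="a photo in the style of the reference image",
    image=init_image,
    strength=0.8,
    ip_adapter_image_embeds=image_embeds,
).images[0]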
@@ -429,8 +481,12 @@ def check_inputs(
         strength,
         height,
         width,
+        negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -467,12 +523,32 @@ def check_inputs(
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )

+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
-
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

@@ -586,6 +662,9 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
         image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -598,6 +677,12 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -659,6 +744,17 @@ def __call__(
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image:
+                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -697,8 +793,12 @@ def __call__(
             strength,
             height,
             width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -724,6 +824,7 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
@@ -738,6 +839,21 @@ def __call__(
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
+        if do_true_cfg:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                _,
+            ) = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )

         # 4.Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
@@ -791,11 +907,42 @@ def __call__(
         else:
             guidance = None

+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -810,6 +957,22 @@ def __call__(
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
tests/pipelines/flux/test_pipeline_flux_img2img.py

Lines changed: 4 additions & 2 deletions
@@ -12,13 +12,13 @@
     torch_device,
 )

-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineTesterMixin,FluxIPAdapterTesterMixin


 enable_full_determinism()


-class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin,FluxIPAdapterTesterMixin):
     pipeline_class = FluxImg2ImgPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
     batch_params = frozenset(["prompt"])
@@ -85,6 +85,8 @@ def get_dummy_components(self):
             "tokenizer_2": tokenizer_2,
             "transformer": transformer,
             "vae": vae,
+            "image_encoder": None,
+            "feature_extractor": None,
         }

     def get_dummy_inputs(self, device, seed=0):
