import numpy as np
import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+    T5EncoderModel,
+    T5TokenizerFast,
+)

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
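Note for context: `image_encoder` and `feature_extractor` are wired in below as optional components, so checkpoints without a CLIP vision tower keep loading. A minimal loading sketch (the checkpoint id is illustrative, not part of this change):

```python
import torch
from diffusers import FluxImg2ImgPipeline

# Loads without the new components; both should stay None until an
# IP-Adapter (and its image encoder) are attached later.
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
print(pipe.image_encoder, pipe.feature_extractor)  # expected: None None
```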
@@ -159,7 +166,7 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


-class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin):
    r"""
    The Flux pipeline for image-to-image generation.

@@ -186,8 +193,8 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
    """

-    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
@@ -199,6 +206,8 @@ def __init__(
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
    ):
        super().__init__()

@@ -210,6 +219,8 @@ def __init__(
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
@@ -394,6 +405,47 @@ def encode_prompt(
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        return image_embeds
+
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have the same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
+            ):
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+
+                image_embeds.append(single_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
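The tensor bookkeeping in `prepare_ip_adapter_image_embeds` is easy to misread, so here is the same shape flow with dummy tensors (the embedding width is an assumption; CLIP-L `image_embeds` would give 768):

```python
import torch

num_images_per_prompt = 2
emb_dim = 768  # assumed width, matching CLIP-L's projected image embeddings

# encode_image(...) returns (num_images, emb_dim); with num_images = 1 and the
# [None, :] above, each IP-Adapter contributes a (1, 1, emb_dim) tensor.
single_image_embeds = torch.randn(1, emb_dim)[None, :]

# The final loop tiles along dim 0 once per image in the batch.
tiled = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
print(tiled.shape)  # torch.Size([2, 1, 768]) -> (batch_size, num_images, emb_dim)
```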
@@ -429,8 +481,12 @@ def check_inputs(
        strength,
        height,
        width,
+        negative_prompt=None,
+        negative_prompt_2=None,
        prompt_embeds=None,
+        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
@@ -467,12 +523,32 @@ def check_inputs(
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )

+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )
-
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

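These guards mirror the other Flux pipelines. For reference, the new shape check restated outside the class with made-up embedding shapes:

```python
import torch

# check_inputs now rejects prompt/negative embeddings whose shapes disagree.
prompt_embeds = torch.zeros(1, 512, 4096)           # made-up (batch, seq, dim)
negative_prompt_embeds = torch.zeros(1, 256, 4096)  # mismatched sequence length
if prompt_embeds.shape != negative_prompt_embeds.shape:
    print(f"would raise ValueError: {prompt_embeds.shape} != {negative_prompt_embeds.shape}")
```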
@@ -586,6 +662,9 @@ def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
@@ -598,6 +677,12 @@ def __call__(
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -659,6 +744,17 @@ def __call__(
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list with a length equal to the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional negative image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated negative image embeddings for IP-Adapter. It should be a list with a length equal to
+                the number of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. If not provided, embeddings are computed from the `negative_ip_adapter_image` input
+                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
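A sketch of how the new arguments compose end to end. The adapter repo, weight name, and image URLs below are placeholders rather than values taken from this diff, and `load_ip_adapter` comes from the `FluxIPAdapterMixin` added above:

```python
import torch
from diffusers import FluxImg2ImgPipeline
from diffusers.utils import load_image

pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",            # placeholder adapter repo
    weight_name="ip_adapter.safetensors",  # placeholder weight file
)

init_image = load_image("https://example.com/init.png")    # placeholder URL
style_image = load_image("https://example.com/style.png")  # placeholder URL

out = pipe(
    prompt="a cat wearing a spacesuit",
    negative_prompt="blurry, low quality",
    true_cfg_scale=4.0,  # > 1 together with a negative prompt enables the new true-CFG branch
    image=init_image,
    ip_adapter_image=style_image,
    strength=0.8,
).images[0]
```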
@@ -697,8 +793,12 @@ def __call__(
            strength,
            height,
            width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )
@@ -724,6 +824,7 @@ def __call__(
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
+        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
        (
            prompt_embeds,
            pooled_prompt_embeds,
@@ -738,6 +839,21 @@ def __call__(
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
+        if do_true_cfg:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                _,
+            ) = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )

        # 4. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
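For reference, the default flow-matching sigma schedule on the context line above, worked through for a small step count:

```python
import numpy as np

num_inference_steps = 4
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
print(sigmas)  # [1.   0.75 0.5  0.25] -- descends from pure noise toward the image
```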
@@ -791,11 +907,42 @@ def __call__(
        else:
            guidance = None

+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+
        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)
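When only one side of the IP-Adapter input is supplied, the branch above fabricates an all-zero image for the other side so both CFG passes see an adapter input of the same layout. The substitute in isolation (sizes are illustrative):

```python
import numpy as np
from PIL import Image

width, height = 512, 512
# Same construction as above: a uint8 zero array, i.e. a black RGB image that
# serves as a contentless reference on the missing branch.
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
print(Image.fromarray(negative_ip_adapter_image).size)  # (512, 512)
```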
@@ -810,6 +957,22 @@ def __call__(
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
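The mixing rule added above is a plain extrapolation between the two predictions. A worked example showing that `true_cfg_scale == 1.0` collapses to the positive prediction, which is why the extra negative pass is skipped whenever `do_true_cfg` is False:

```python
import torch

noise_pred = torch.tensor([1.0, 2.0])      # toy positive prediction
neg_noise_pred = torch.tensor([0.5, 0.5])  # toy negative prediction

for true_cfg_scale in (1.0, 4.0):
    mixed = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
    print(true_cfg_scale, mixed)
# 1.0 tensor([1., 2.])          -> identical to noise_pred
# 4.0 tensor([2.5000, 6.5000])  -> pushed away from the negative prediction
```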