1818import numpy as np
1919import torch
2020from transformers import (
21+ CLIPImageProcessor ,
2122 CLIPTextModel ,
2223 CLIPTokenizer ,
24+ CLIPVisionModelWithProjection ,
2325 T5EncoderModel ,
2426 T5TokenizerFast ,
2527)
2628
2729from ...image_processor import PipelineImageInput , VaeImageProcessor
28- from ...loaders import FluxLoraLoaderMixin , FromSingleFileMixin , TextualInversionLoaderMixin
30+ from ...loaders import FluxIPAdapterMixin , FluxLoraLoaderMixin , FromSingleFileMixin , TextualInversionLoaderMixin
2931from ...models .autoencoders import AutoencoderKL
3032from ...models .controlnets .controlnet_flux import FluxControlNetModel , FluxMultiControlNetModel
3133from ...models .transformers import FluxTransformer2DModel
@@ -171,7 +173,7 @@ def retrieve_timesteps(
171173 return timesteps , num_inference_steps
172174
173175
174- class FluxControlNetPipeline (DiffusionPipeline , FluxLoraLoaderMixin , FromSingleFileMixin ):
176+ class FluxControlNetPipeline (DiffusionPipeline , FluxLoraLoaderMixin , FromSingleFileMixin , FluxIPAdapterMixin ):
175177 r"""
176178 The Flux pipeline for text-to-image generation.
177179
@@ -198,8 +200,8 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
198200 [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
199201 """
200202
201- model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
202- _optional_components = []
203+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder-> transformer->vae"
204+ _optional_components = ["image_encoder" , "feature_extractor" ]
203205 _callback_tensor_inputs = ["latents" , "prompt_embeds" ]
204206
205207 def __init__ (
@@ -214,6 +216,8 @@ def __init__(
214216 controlnet : Union [
215217 FluxControlNetModel , List [FluxControlNetModel ], Tuple [FluxControlNetModel ], FluxMultiControlNetModel
216218 ],
219+ image_encoder : CLIPVisionModelWithProjection = None ,
220+ feature_extractor : CLIPImageProcessor = None ,
217221 ):
218222 super ().__init__ ()
219223 if isinstance (controlnet , (list , tuple )):
@@ -228,6 +232,8 @@ def __init__(
228232 transformer = transformer ,
229233 scheduler = scheduler ,
230234 controlnet = controlnet ,
235+ image_encoder = image_encoder ,
236+ feature_extractor = feature_extractor ,
231237 )
232238 self .vae_scale_factor = 2 ** (len (self .vae .config .block_out_channels ) - 1 ) if getattr (self , "vae" , None ) else 8
233239 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
@@ -413,14 +419,62 @@ def encode_prompt(
413419
414420 return prompt_embeds , pooled_prompt_embeds , text_ids
415421
422+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
423+ def encode_image (self , image , device , num_images_per_prompt ):
424+ dtype = next (self .image_encoder .parameters ()).dtype
425+
426+ if not isinstance (image , torch .Tensor ):
427+ image = self .feature_extractor (image , return_tensors = "pt" ).pixel_values
428+
429+ image = image .to (device = device , dtype = dtype )
430+ image_embeds = self .image_encoder (image ).image_embeds
431+ image_embeds = image_embeds .repeat_interleave (num_images_per_prompt , dim = 0 )
432+ return image_embeds
433+
434+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
435+ def prepare_ip_adapter_image_embeds (
436+ self , ip_adapter_image , ip_adapter_image_embeds , device , num_images_per_prompt
437+ ):
438+ image_embeds = []
439+ if ip_adapter_image_embeds is None :
440+ if not isinstance (ip_adapter_image , list ):
441+ ip_adapter_image = [ip_adapter_image ]
442+
443+ if len (ip_adapter_image ) != len (self .transformer .encoder_hid_proj .image_projection_layers ):
444+ raise ValueError (
445+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got { len (ip_adapter_image )} images and { len (self .transformer .encoder_hid_proj .image_projection_layers )} IP Adapters."
446+ )
447+
448+ for single_ip_adapter_image , image_proj_layer in zip (
449+ ip_adapter_image , self .transformer .encoder_hid_proj .image_projection_layers
450+ ):
451+ single_image_embeds = self .encode_image (single_ip_adapter_image , device , 1 )
452+
453+ image_embeds .append (single_image_embeds [None , :])
454+ else :
455+ for single_image_embeds in ip_adapter_image_embeds :
456+ image_embeds .append (single_image_embeds )
457+
458+ ip_adapter_image_embeds = []
459+ for i , single_image_embeds in enumerate (image_embeds ):
460+ single_image_embeds = torch .cat ([single_image_embeds ] * num_images_per_prompt , dim = 0 )
461+ single_image_embeds = single_image_embeds .to (device = device )
462+ ip_adapter_image_embeds .append (single_image_embeds )
463+
464+ return ip_adapter_image_embeds
465+
416466 def check_inputs (
417467 self ,
418468 prompt ,
419469 prompt_2 ,
420470 height ,
421471 width ,
472+ negative_prompt = None ,
473+ negative_prompt_2 = None ,
422474 prompt_embeds = None ,
475+ negative_prompt_embeds = None ,
423476 pooled_prompt_embeds = None ,
477+ negative_pooled_prompt_embeds = None ,
424478 callback_on_step_end_tensor_inputs = None ,
425479 max_sequence_length = None ,
426480 ):
@@ -455,10 +509,33 @@ def check_inputs(
455509 elif prompt_2 is not None and (not isinstance (prompt_2 , str ) and not isinstance (prompt_2 , list )):
456510 raise ValueError (f"`prompt_2` has to be of type `str` or `list` but is { type (prompt_2 )} " )
457511
512+ if negative_prompt is not None and negative_prompt_embeds is not None :
513+ raise ValueError (
514+ f"Cannot forward both `negative_prompt`: { negative_prompt } and `negative_prompt_embeds`:"
515+ f" { negative_prompt_embeds } . Please make sure to only forward one of the two."
516+ )
517+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None :
518+ raise ValueError (
519+ f"Cannot forward both `negative_prompt_2`: { negative_prompt_2 } and `negative_prompt_embeds`:"
520+ f" { negative_prompt_embeds } . Please make sure to only forward one of the two."
521+ )
522+
523+ if prompt_embeds is not None and negative_prompt_embeds is not None :
524+ if prompt_embeds .shape != negative_prompt_embeds .shape :
525+ raise ValueError (
526+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
527+ f" got: `prompt_embeds` { prompt_embeds .shape } != `negative_prompt_embeds`"
528+ f" { negative_prompt_embeds .shape } ."
529+ )
530+
458531 if prompt_embeds is not None and pooled_prompt_embeds is None :
459532 raise ValueError (
460533 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
461534 )
535+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None :
536+ raise ValueError (
537+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
538+ )
462539
463540 if max_sequence_length is not None and max_sequence_length > 512 :
464541 raise ValueError (f"`max_sequence_length` cannot be greater than 512 but is { max_sequence_length } " )
@@ -597,6 +674,9 @@ def __call__(
597674 self ,
598675 prompt : Union [str , List [str ]] = None ,
599676 prompt_2 : Optional [Union [str , List [str ]]] = None ,
677+ negative_prompt : Union [str , List [str ]] = None ,
678+ negative_prompt_2 : Optional [Union [str , List [str ]]] = None ,
679+ true_cfg_scale : float = 1.0 ,
600680 height : Optional [int ] = None ,
601681 width : Optional [int ] = None ,
602682 num_inference_steps : int = 28 ,
@@ -612,6 +692,12 @@ def __call__(
612692 latents : Optional [torch .FloatTensor ] = None ,
613693 prompt_embeds : Optional [torch .FloatTensor ] = None ,
614694 pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
695+ ip_adapter_image : Optional [PipelineImageInput ] = None ,
696+ ip_adapter_image_embeds : Optional [List [torch .Tensor ]] = None ,
697+ negative_ip_adapter_image : Optional [PipelineImageInput ] = None ,
698+ negative_ip_adapter_image_embeds : Optional [List [torch .Tensor ]] = None ,
699+ negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
700+ negative_pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
615701 output_type : Optional [str ] = "pil" ,
616702 return_dict : bool = True ,
617703 joint_attention_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -679,6 +765,17 @@ def __call__(
679765 pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
680766 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
681767 If not provided, pooled text embeddings will be generated from `prompt` input argument.
768+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
769+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
770+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
771+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
772+ provided, embeddings are computed from the `ip_adapter_image` input argument.
773+ negative_ip_adapter_image:
774+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
775+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
776+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
777+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
778+ provided, embeddings are computed from the `ip_adapter_image` input argument.
682779 output_type (`str`, *optional*, defaults to `"pil"`):
683780 The output format of the generate image. Choose between
684781 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -727,8 +824,12 @@ def __call__(
727824 prompt_2 ,
728825 height ,
729826 width ,
827+ negative_prompt = negative_prompt ,
828+ negative_prompt_2 = negative_prompt_2 ,
730829 prompt_embeds = prompt_embeds ,
830+ negative_prompt_embeds = negative_prompt_embeds ,
731831 pooled_prompt_embeds = pooled_prompt_embeds ,
832+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds ,
732833 callback_on_step_end_tensor_inputs = callback_on_step_end_tensor_inputs ,
733834 max_sequence_length = max_sequence_length ,
734835 )
@@ -752,6 +853,7 @@ def __call__(
752853 lora_scale = (
753854 self .joint_attention_kwargs .get ("scale" , None ) if self .joint_attention_kwargs is not None else None
754855 )
856+ do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
755857 (
756858 prompt_embeds ,
757859 pooled_prompt_embeds ,
@@ -766,6 +868,21 @@ def __call__(
766868 max_sequence_length = max_sequence_length ,
767869 lora_scale = lora_scale ,
768870 )
871+ if do_true_cfg :
872+ (
873+ negative_prompt_embeds ,
874+ negative_pooled_prompt_embeds ,
875+ _ ,
876+ ) = self .encode_prompt (
877+ prompt = negative_prompt ,
878+ prompt_2 = negative_prompt_2 ,
879+ prompt_embeds = negative_prompt_embeds ,
880+ pooled_prompt_embeds = negative_pooled_prompt_embeds ,
881+ device = device ,
882+ num_images_per_prompt = num_images_per_prompt ,
883+ max_sequence_length = max_sequence_length ,
884+ lora_scale = lora_scale ,
885+ )
769886
770887 # 3. Prepare control image
771888 num_channels_latents = self .transformer .config .in_channels // 4
@@ -899,12 +1016,43 @@ def __call__(
8991016 ]
9001017 controlnet_keep .append (keeps [0 ] if isinstance (self .controlnet , FluxControlNetModel ) else keeps )
9011018
1019+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None ) and (
1020+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
1021+ ):
1022+ negative_ip_adapter_image = np .zeros ((width , height , 3 ), dtype = np .uint8 )
1023+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None ) and (
1024+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
1025+ ):
1026+ ip_adapter_image = np .zeros ((width , height , 3 ), dtype = np .uint8 )
1027+
1028+ if self .joint_attention_kwargs is None :
1029+ self ._joint_attention_kwargs = {}
1030+
1031+ image_embeds = None
1032+ negative_image_embeds = None
1033+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None :
1034+ image_embeds = self .prepare_ip_adapter_image_embeds (
1035+ ip_adapter_image ,
1036+ ip_adapter_image_embeds ,
1037+ device ,
1038+ batch_size * num_images_per_prompt ,
1039+ )
1040+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None :
1041+ negative_image_embeds = self .prepare_ip_adapter_image_embeds (
1042+ negative_ip_adapter_image ,
1043+ negative_ip_adapter_image_embeds ,
1044+ device ,
1045+ batch_size * num_images_per_prompt ,
1046+ )
1047+
9021048 # 7. Denoising loop
9031049 with self .progress_bar (total = num_inference_steps ) as progress_bar :
9041050 for i , t in enumerate (timesteps ):
9051051 if self .interrupt :
9061052 continue
9071053
1054+ if image_embeds is not None :
1055+ self ._joint_attention_kwargs ["ip_adapter_image_embeds" ] = image_embeds
9081056 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
9091057 timestep = t .expand (latents .shape [0 ]).to (latents .dtype )
9101058
@@ -960,6 +1108,25 @@ def __call__(
9601108 controlnet_blocks_repeat = controlnet_blocks_repeat ,
9611109 )[0 ]
9621110
1111+ if do_true_cfg :
1112+ if negative_image_embeds is not None :
1113+ self ._joint_attention_kwargs ["ip_adapter_image_embeds" ] = negative_image_embeds
1114+ neg_noise_pred = self .transformer (
1115+ hidden_states = latents ,
1116+ timestep = timestep / 1000 ,
1117+ guidance = guidance ,
1118+ pooled_projections = negative_pooled_prompt_embeds ,
1119+ encoder_hidden_states = negative_prompt_embeds ,
1120+ controlnet_block_samples = controlnet_block_samples ,
1121+ controlnet_single_block_samples = controlnet_single_block_samples ,
1122+ txt_ids = text_ids ,
1123+ img_ids = latent_image_ids ,
1124+ joint_attention_kwargs = self .joint_attention_kwargs ,
1125+ return_dict = False ,
1126+ controlnet_blocks_repeat = controlnet_blocks_repeat ,
1127+ )[0 ]
1128+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred )
1129+
9631130 # compute the previous noisy sample x_t -> x_t-1
9641131 latents_dtype = latents .dtype
9651132 latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
0 commit comments