@@ -178,7 +178,7 @@ class FluxPipeline(
     """

     model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
-    _optional_components = []
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]

     def __init__(
@@ -314,12 +314,17 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]],
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
+        do_true_cfg: bool = False,
     ):
         r"""
@@ -356,24 +361,59 @@ def encode_prompt(
             scale_lora_layers(self.text_encoder_2, lora_scale)

         prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if do_true_cfg and negative_prompt is not None:
+            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_batch_size = len(negative_prompt)
+
+            if negative_batch_size != batch_size:
+                raise ValueError(
+                    f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
+                )
+
+            # Concatenate prompts
+            prompts = prompt + negative_prompt
+            prompts_2 = (
+                prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
+            )
+        else:
+            prompts = prompt
+            prompts_2 = prompt_2

         if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
-            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+            if prompts_2 is None:
+                prompts_2 = prompts

             # We only use the pooled prompt output from the CLIPTextModel
             pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompt,
+                prompt=prompts,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
             )
             prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt_2,
+                prompt=prompts_2,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )

+        if do_true_cfg and negative_prompt is not None:
+            # Split embeddings back into positive and negative parts
+            total_batch_size = batch_size * num_images_per_prompt
+            positive_indices = slice(0, total_batch_size)
+            negative_indices = slice(total_batch_size, 2 * total_batch_size)
+
+            positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
+            negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
+
+            positive_prompt_embeds = prompt_embeds[positive_indices]
+            negative_prompt_embeds = prompt_embeds[negative_indices]
+
+            pooled_prompt_embeds = positive_pooled_prompt_embeds
+            prompt_embeds = positive_prompt_embeds
+
+        # Unscale LoRA layers
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
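The hunk above encodes positive and negative prompts in a single text-encoder batch (positives first, negatives second) and then splits the result by row index. A standalone sketch of that slicing arithmetic with dummy tensors, where all shapes are placeholder assumptions:

    import torch

    batch_size, num_images_per_prompt, seq_len, dim = 2, 3, 512, 64
    total = batch_size * num_images_per_prompt
    # Joint batch layout: all positive rows first, then all negative rows.
    joint = torch.randn(2 * total, seq_len, dim)

    positive = joint[slice(0, total)]            # rows 0 .. total-1
    negative = joint[slice(total, 2 * total)]    # rows total .. 2*total-1
    assert positive.shape == negative.shape == (total, seq_len, dim)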
@@ -387,7 +427,16 @@ def encode_prompt(
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        if do_true_cfg and negative_prompt is not None:
+            return (
+                prompt_embeds,
+                pooled_prompt_embeds,
+                text_ids,
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            )
+        else:
+            return prompt_embeds, pooled_prompt_embeds, text_ids, None, None

     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
@@ -439,8 +488,12 @@ def check_inputs(
         prompt_2,
         height,
         width,
+        negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -475,10 +528,33 @@ def check_inputs(
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )

         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
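These checks mirror the existing positive-prompt validation: a raw negative prompt and precomputed negative embeddings are mutually exclusive, and shapes must agree when embeddings are passed directly. A hedged illustration of the first guard, assuming a loaded `pipe`; the embedding shape is a placeholder, not the real T5 width:

    import torch

    try:
        pipe.check_inputs(
            prompt="a cat",
            prompt_2=None,
            height=1024,
            width=1024,
            negative_prompt="blurry",
            negative_prompt_embeds=torch.randn(1, 512, 4096),  # placeholder shape
        )
    except ValueError as err:
        print(err)  # Cannot forward both `negative_prompt` ... and `negative_prompt_embeds` ...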
@@ -607,6 +683,9 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
@@ -619,6 +698,10 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -673,6 +756,11 @@ def __call__(
                 Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                 IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                 provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -710,8 +798,12 @@ def __call__(
             prompt_2,
             height,
             width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -733,19 +825,27 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
             text_ids,
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
+            do_true_cfg=do_true_cfg,
         )

         # 4. Prepare latent variables
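One consequence of the `do_true_cfg = true_cfg > 1 and negative_prompt is not None` gate: unlike the batched classifier-free guidance in the Stable Diffusion pipelines, which folds negative and positive conditioning into a single transformer batch, this implementation issues a second, sequential transformer forward per step (see the denoising loop below), trading latency for memory. A rough cost model, offered as an assumption rather than a measurement:

    def transformer_calls(num_inference_steps: int, do_true_cfg: bool) -> int:
        # One positive pass per step, plus one negative pass when true CFG is on.
        return num_inference_steps * (2 if do_true_cfg else 1)

    assert transformer_calls(28, True) == 56
    assert transformer_calls(28, False) == 28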
@@ -788,23 +888,43 @@ def __call__(
         else:
             guidance = None

+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        image_embeds = None
+        negative_image_embeds = None
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
                 ip_adapter_image,
                 ip_adapter_image_embeds,
                 device,
                 batch_size * num_images_per_prompt,
             )
-            if self.joint_attention_kwargs is None:
-                self._joint_attention_kwargs = {}
             self._joint_attention_kwargs["image_projection"] = image_embeds
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )

         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

+                self._joint_attention_kwargs["image_projection"] = image_embeds
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

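When only one side of the IP-Adapter inputs is supplied, the other side is backfilled with an all-zeros (black) placeholder so the positive and negative passes both receive a valid image input. A toy sketch of that rule with hypothetical inputs; the hunk uses `np.zeros((width, height, 3))`, and since the CLIP feature extractor resizes the array and the image is uniform, the axis order is harmless here:

    import numpy as np

    width = height = 1024
    ip_adapter_image, negative_ip_adapter_image = "style_ref.png", None  # hypothetical
    if ip_adapter_image is not None and negative_ip_adapter_image is None:
        negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)  # black placeholder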
@@ -820,6 +940,21 @@ def __call__(
                     return_dict=False,
                 )[0]

+                if do_true_cfg:
+                    self._joint_attention_kwargs["image_projection"] = negative_image_embeds
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
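The closing combination line is textbook classifier-free guidance, noise_pred = neg + true_cfg * (pos - neg), computed from a real unconditional (negative) prediction rather than Flux's distilled `guidance_scale` embedding; the two mechanisms stack. A hedged end-to-end sketch, assuming the `black-forest-labs/FLUX.1-dev` checkpoint and a diffusers build containing this commit:

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    pipe.enable_model_cpu_offload()  # fits smaller GPUs at the cost of speed

    image = pipe(
        prompt="a tiny astronaut hatching from an egg on the moon",
        negative_prompt="blurry, low contrast, deformed",
        true_cfg=4.0,        # > 1 switches the real negative pass on
        guidance_scale=3.5,  # distilled guidance embed, unchanged by this commit
        num_inference_steps=28,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
    image.save("flux_true_cfg.png")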