@@ -314,12 +314,17 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]],
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
+        do_true_cfg: bool = False,
     ):
         r"""
 
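For orientation, the next hunk batches negative prompts together with positive ones so each text encoder runs only once, then slices the result apart. A minimal sketch of that concatenate-encode-split pattern (the `encode_fn` callable is a hypothetical stand-in, not part of the pipeline):

```python
import torch

def encode_with_negatives(encode_fn, prompt, negative_prompt):
    # `encode_fn` is a hypothetical text-encoder call returning one
    # embedding row per input string, in order.
    prompts = prompt + negative_prompt  # single forward pass over both halves
    embeds = encode_fn(prompts)
    split = len(prompt)                 # positive rows come first
    return embeds[:split], embeds[split:]

# Toy usage with a fake "encoder" that just counts characters:
pos, neg = encode_with_negatives(
    lambda ps: torch.tensor([[float(len(p))] for p in ps]),
    ["a cat", "a dog"],
    ["blurry", "low quality"],
)
```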
@@ -356,24 +361,59 @@ def encode_prompt(
             scale_lora_layers(self.text_encoder_2, lora_scale)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if do_true_cfg and negative_prompt is not None:
+            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_batch_size = len(negative_prompt)
+
+            if negative_batch_size != batch_size:
+                raise ValueError(
+                    f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
+                )
+
+            # Concatenate prompts
+            prompts = prompt + negative_prompt
+            # Normalize both to lists first so `+` concatenates prompt lists
+            # rather than strings when single prompts are passed
+            if prompt_2 is not None and negative_prompt_2 is not None:
+                prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+                negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+                prompts_2 = prompt_2 + negative_prompt_2
+            else:
+                prompts_2 = None
+        else:
+            prompts = prompt
+            prompts_2 = prompt_2
 
         if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
-            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+            if prompts_2 is None:
+                prompts_2 = prompts
 
             # We only use the pooled prompt output from the CLIPTextModel
             pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompt,
+                prompt=prompts,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
             )
             prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt_2,
+                prompt=prompts_2,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
 
+        if do_true_cfg and negative_prompt is not None:
+            # Split embeddings back into positive and negative parts
+            total_batch_size = batch_size * num_images_per_prompt
+            positive_indices = slice(0, total_batch_size)
+            negative_indices = slice(total_batch_size, 2 * total_batch_size)
+
+            positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
+            negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
+
+            positive_prompt_embeds = prompt_embeds[positive_indices]
+            negative_prompt_embeds = prompt_embeds[negative_indices]
+
+            pooled_prompt_embeds = positive_pooled_prompt_embeds
+            prompt_embeds = positive_prompt_embeds
+
+        # Unscale LoRA layers
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
@@ -387,7 +427,16 @@ def encode_prompt(
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        if do_true_cfg and negative_prompt is not None:
+            return (
+                prompt_embeds,
+                pooled_prompt_embeds,
+                text_ids,
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            )
+        else:
+            return prompt_embeds, pooled_prompt_embeds, text_ids, None, None
 
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
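With these changes `encode_prompt` always returns a 5-tuple; the last two entries are `None` unless true CFG is active. A hedged usage sketch, assuming an already-loaded pipeline instance `pipe`:

```python
# Hypothetical call against the extended method; argument names follow the diff.
(
    prompt_embeds,
    pooled_prompt_embeds,
    text_ids,
    negative_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="a photo of a cat",
    prompt_2=None,
    negative_prompt="blurry, low quality",
    do_true_cfg=True,
)
# Without a negative prompt (or with do_true_cfg=False) the last two are None.
```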
@@ -439,8 +488,12 @@ def check_inputs(
         prompt_2,
         height,
         width,
+        negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -475,10 +528,33 @@ def check_inputs(
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
 
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
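These checks mirror the existing positive-prompt validation. Illustrative only, assuming a pipeline instance `pipe` and placeholder tensor shapes: mixing a raw negative prompt with precomputed negative embeddings now fails fast.

```python
import torch

try:
    pipe.check_inputs(
        prompt="a cat",
        prompt_2=None,
        height=1024,
        width=1024,
        negative_prompt="blurry",
        negative_prompt_embeds=torch.zeros(1, 512, 4096),  # placeholder shape
    )
except ValueError as err:
    print(err)  # "Cannot forward both `negative_prompt`: ... and `negative_prompt_embeds`: ..."
```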
@@ -607,6 +683,9 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
@@ -619,6 +698,10 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -673,6 +756,11 @@ def __call__(
                 Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                 IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                 provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -710,8 +798,12 @@ def __call__(
             prompt_2,
             height,
             width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -733,21 +825,34 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
             text_ids,
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
+            do_true_cfg=do_true_cfg,
         )
 
+        if do_true_cfg:
+            # Concatenate embeddings: negatives first, positives second
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, latent_image_ids = self.prepare_latents(
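Note the concatenation order: negatives first, then positives. The denoising loop later relies on `noise_pred.chunk(2)` returning the halves in that same order. A toy check of the invariant:

```python
import torch

# Illustrative shapes only; chunk(2) must undo this exact concatenation order.
batch, seq, dim = 2, 512, 64
pos = torch.randn(batch, seq, dim)
neg = torch.randn(batch, seq, dim)
both = torch.cat([neg, pos], dim=0)  # negatives first, as in the diff
neg_back, pos_back = both.chunk(2)
assert torch.equal(neg_back, neg) and torch.equal(pos_back, pos)
```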
@@ -781,12 +886,17 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
-        # handle guidance
-        if self.transformer.config.guidance_embeds:
-            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-            guidance = guidance.expand(latents.shape[0])
-        else:
-            guidance = None
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
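The hunk above keeps the IP-Adapter inputs symmetric: if only one side of the positive/negative pair is supplied, the other side falls back to a black placeholder image so both halves of the doubled batch receive image conditioning. Restated in isolation (function name and signature are illustrative, not part of the pipeline):

```python
import numpy as np

def fill_missing_ip_adapter_side(pos_img, pos_emb, neg_img, neg_emb, width, height):
    # Whichever side of the positive/negative IP-Adapter pair is missing
    # gets a black placeholder image, matching the diff's fallback rule.
    blank = np.zeros((width, height, 3), dtype=np.uint8)
    if (pos_img is not None or pos_emb is not None) and neg_img is None and neg_emb is None:
        neg_img = blank
    elif pos_img is None and pos_emb is None and (neg_img is not None or neg_emb is not None):
        pos_img = blank
    return pos_img, neg_img
```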
@@ -795,21 +905,37 @@ def __call__(
                 device,
                 batch_size * num_images_per_prompt,
             )
-            if self.joint_attention_kwargs is None:
-                self._joint_attention_kwargs = {}
             self._joint_attention_kwargs["image_projection"] = image_embeds
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+            image_embeds = self._joint_attention_kwargs["image_projection"]
+            self._joint_attention_kwargs["image_projection"] = torch.cat([negative_image_embeds, image_embeds])
 
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
+                latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
+
+                # handle guidance
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+                    guidance = guidance.expand(latent_model_input.shape[0])
+                else:
+                    guidance = None
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
 
                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
@@ -820,6 +946,10 @@ def __call__(
                     return_dict=False,
                 )[0]
 
+                if do_true_cfg:
+                    neg_noise_pred, noise_pred = noise_pred.chunk(2)
+                    noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
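Taken together: when `true_cfg > 1` and a negative prompt is given, each step runs the transformer on a doubled batch and recombines the two predictions with the standard classifier-free-guidance formula, `noise_pred = neg + true_cfg * (pos - neg)`. A minimal end-to-end sketch, assuming the patched pipeline is exposed as `FluxPipeline` and a Flux checkpoint such as `black-forest-labs/FLUX.1-dev` is available (all parameter values illustrative):

```python
import torch
from diffusers import FluxPipeline  # assumes this patched pipeline class

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="a photorealistic portrait of an astronaut",
    negative_prompt="blurry, oversaturated, deformed hands",
    true_cfg=4.0,        # > 1 activates the real CFG path added above
    guidance_scale=1.0,  # Flux's distilled guidance embedding, kept low here
    num_inference_steps=28,
).images[0]
image.save("astronaut.png")
```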