@@ -177,7 +177,7 @@ class FluxPipeline(
             [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
     _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
 
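The reordered `model_cpu_offload_seq` matters once the optional IP-Adapter `image_encoder` is loaded: `enable_model_cpu_offload()` walks this sequence and keeps roughly one component at a time on the accelerator, so the image encoder now participates in the same on/off cycle as the text encoders, transformer, and VAE instead of being skipped. A minimal usage sketch (checkpoint id and dtype are illustrative, not prescribed by this change):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
# Each submodel is moved to the GPU only while it is needed, in the order
# given by model_cpu_offload_seq (now including image_encoder when present).
pipe.enable_model_cpu_offload()

image = pipe("a tiny astronaut hatching from an egg on the moon").images[0]
```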
@@ -314,17 +314,12 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]],
-        negative_prompt: Union[str, List[str]] = None,
-        negative_prompt_2: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
-        do_true_cfg: bool = False,
     ):
         r"""
 
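With the negative-prompt arguments removed, `encode_prompt` encodes exactly one batch of prompts per call and, as the final hunk of this method shows, always returns a flat three-tuple. A sketch of the new call shape, reusing the `pipe` from the sketch above (argument values are illustrative):

```python
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt="a photo of a forest in autumn",
    prompt_2=None,  # falls back to `prompt` for the T5 encoder, per the next hunk
    device=pipe._execution_device,
    num_images_per_prompt=1,
    max_sequence_length=512,
)
```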
@@ -361,62 +356,24 @@ def encode_prompt(
             scale_lora_layers(self.text_encoder_2, lora_scale)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if do_true_cfg and negative_prompt is not None:
-            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_batch_size = len(negative_prompt)
-
-            if negative_batch_size != batch_size:
-                raise ValueError(
-                    f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
-                )
-
-            # Concatenate prompts
-            prompts = prompt + negative_prompt
-            prompts_2 = (
-                prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
-            )
-        else:
-            prompts = prompt
-            prompts_2 = prompt_2
 
         if prompt_embeds is None:
-            if prompts_2 is None:
-                prompts_2 = prompts
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
 
             # We only use the pooled prompt output from the CLIPTextModel
             pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompts,
+                prompt=prompt,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
             )
             prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompts_2,
+                prompt=prompt_2,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
 
-        if do_true_cfg and negative_prompt is not None:
-            # Split embeddings back into positive and negative parts
-            total_batch_size = batch_size * num_images_per_prompt
-            positive_indices = slice(0, total_batch_size)
-            negative_indices = slice(total_batch_size, 2 * total_batch_size)
-
-            positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
-            negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
-
-            positive_prompt_embeds = prompt_embeds[positive_indices]
-            negative_prompt_embeds = prompt_embeds[negative_indices]
-
-            pooled_prompt_embeds = positive_pooled_prompt_embeds
-            prompt_embeds = positive_prompt_embeds
-
-        # Unscale LoRA layers
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
@@ -430,16 +387,7 @@ def encode_prompt(
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
-        if do_true_cfg and negative_prompt is not None:
-            return (
-                prompt_embeds,
-                pooled_prompt_embeds,
-                text_ids,
-                negative_prompt_embeds,
-                negative_pooled_prompt_embeds,
-            )
-        else:
-            return prompt_embeds, pooled_prompt_embeds, text_ids, None, None
+        return prompt_embeds, pooled_prompt_embeds, text_ids
 
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
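For orientation, the three returned tensors are not interchangeable: CLIP contributes only a pooled embedding, T5 contributes the full token sequence, and `text_ids` is the zero-filled position-id tensor built just above. A rough shape check, assuming Flux's stock CLIP-L and T5-XXL encoders and the defaults from the sketch above:

```python
print(pooled_prompt_embeds.shape)  # torch.Size([1, 768]): pooled CLIP output
print(prompt_embeds.shape)         # torch.Size([1, 512, 4096]): T5 hidden states
print(text_ids.shape)              # torch.Size([512, 3]): zeros, per the torch.zeros call above
```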
@@ -832,22 +780,29 @@ def __call__(
             prompt_embeds,
             pooled_prompt_embeds,
             text_ids,
-            negative_prompt_embeds,
-            negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
-            negative_prompt=negative_prompt,
-            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
-            do_true_cfg=do_true_cfg,
+        )
+        (
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            _,
+        ) = self.encode_prompt(
+            prompt=negative_prompt,
+            prompt_2=negative_prompt_2,
+            prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
         )
 
         # 4. Prepare latent variables
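Net effect at the call site: positive and negative prompts are now encoded by two symmetric `encode_prompt` calls rather than one batched call, and the two embedding sets stay separate for true CFG during denoising. An end-to-end sketch, assuming `true_cfg_scale` gates true CFG as in upstream diffusers; note that in this hunk the second `encode_prompt` call is unconditional, so a negative prompt (or precomputed negative embeddings) should be provided:

```python
image = pipe(
    prompt="a tiny astronaut hatching from an egg on the moon",
    negative_prompt="blurry, low quality",
    true_cfg_scale=4.0,  # > 1 enables true CFG with the negative embeddings
    guidance_scale=3.5,
    num_inference_steps=28,
).images[0]
image.save("flux_true_cfg.png")
```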