@@ -351,10 +351,10 @@ def encode_prompt(
351351 if do_classifier_free_guidance and negative_pooled_prompt_embeds is None :
352352 negative_prompt = negative_prompt or ""
353353 negative_prompt = [negative_prompt ] if isinstance (negative_prompt , str ) else negative_prompt
354-
354+
355355 if len (negative_prompt ) > 1 and len (negative_prompt ) != batch_size :
356356 raise ValueError (f"negative_prompt must be of length 1 or { batch_size } " )
357-
357+
358358 negative_pooled_prompt_embeds_1 = self ._get_clip_prompt_embeds (
359359 self .tokenizer , self .text_encoder , negative_prompt , max_sequence_length , device , dtype
360360 )
@@ -406,7 +406,7 @@ def encode_prompt(
406406 raise ValueError (f"prompt_3 must be of length 1 or { batch_size } " )
407407
408408 t5_prompt_embeds = self ._get_t5_prompt_embeds (prompt_3 , max_sequence_length , device , dtype )
409-
409+
410410 if t5_prompt_embeds .shape [0 ] == 1 and batch_size > 1 :
411411 t5_prompt_embeds = t5_prompt_embeds .repeat (batch_size , 1 , 1 )
412412
@@ -420,7 +420,7 @@ def encode_prompt(
420420 negative_t5_prompt_embeds = self ._get_t5_prompt_embeds (
421421 negative_prompt_3 , max_sequence_length , device , dtype
422422 )
423-
423+
424424 if negative_t5_prompt_embeds .shape [0 ] == 1 and batch_size > 1 :
425425 negative_t5_prompt_embeds = negative_t5_prompt_embeds .repeat (batch_size , 1 , 1 )
426426
@@ -432,7 +432,7 @@ def encode_prompt(
432432 raise ValueError (f"prompt_4 must be of length 1 or { batch_size } " )
433433
434434 llama3_prompt_embeds = self ._get_llama3_prompt_embeds (prompt_4 , max_sequence_length , device , dtype )
435-
435+
436436 if llama3_prompt_embeds .shape [0 ] == 1 and batch_size > 1 :
437437 llama3_prompt_embeds = llama3_prompt_embeds .repeat (1 , batch_size , 1 , 1 )
438438
@@ -446,10 +446,10 @@ def encode_prompt(
446446 negative_llama3_prompt_embeds = self ._get_llama3_prompt_embeds (
447447 negative_prompt_4 , max_sequence_length , device , dtype
448448 )
449-
449+
450450 if negative_llama3_prompt_embeds .shape [0 ] == 1 and batch_size > 1 :
451451 negative_llama3_prompt_embeds = negative_llama3_prompt_embeds .repeat (1 , batch_size , 1 , 1 )
452-
452+
453453 # duplicate pooled_prompt_embeds for each generation per prompt
454454 pooled_prompt_embeds = pooled_prompt_embeds .repeat (1 , num_images_per_prompt )
455455 pooled_prompt_embeds = pooled_prompt_embeds .view (batch_size * num_images_per_prompt , - 1 )
@@ -472,7 +472,7 @@ def encode_prompt(
472472 llama3_prompt_embeds = llama3_prompt_embeds .repeat (1 , 1 , num_images_per_prompt , 1 )
473473 llama3_prompt_embeds = llama3_prompt_embeds .view (- 1 , batch_size * num_images_per_prompt , seq_len , dim )
474474
475- if do_classifier_free_guidance :
475+ if do_classifier_free_guidance :
476476 # duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt
477477 bs_embed , seq_len = negative_pooled_prompt_embeds .shape
478478 if bs_embed == 1 and batch_size > 1 :
@@ -502,7 +502,14 @@ def encode_prompt(
502502 - 1 , batch_size * num_images_per_prompt , seq_len , dim
503503 )
504504
505- return t5_prompt_embeds , llama3_prompt_embeds , negative_t5_prompt_embeds , negative_llama3_prompt_embeds , pooled_prompt_embeds , negative_pooled_prompt_embeds
505+ return (
506+ t5_prompt_embeds ,
507+ llama3_prompt_embeds ,
508+ negative_t5_prompt_embeds ,
509+ negative_llama3_prompt_embeds ,
510+ pooled_prompt_embeds ,
511+ negative_pooled_prompt_embeds ,
512+ )
506513
507514 def enable_vae_slicing (self ):
508515 r"""
@@ -583,9 +590,13 @@ def check_inputs(
583590 "Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined."
584591 )
585592 elif prompt is None and t5_prompt_embeds is None :
586- raise ValueError ("Provide either `prompt` or `t5_prompt_embeds`. Cannot leave both `prompt` and `t5_prompt_embeds` undefined." )
593+ raise ValueError (
594+ "Provide either `prompt` or `t5_prompt_embeds`. Cannot leave both `prompt` and `t5_prompt_embeds` undefined."
595+ )
587596 elif prompt is None and llama3_prompt_embeds is None :
588- raise ValueError ("Provide either `prompt` or `llama3_prompt_embeds`. Cannot leave both `prompt` and `llama3_prompt_embeds` undefined." )
597+ raise ValueError (
598+ "Provide either `prompt` or `llama3_prompt_embeds`. Cannot leave both `prompt` and `llama3_prompt_embeds` undefined."
599+ )
589600 elif prompt is not None and (not isinstance (prompt , str ) and not isinstance (prompt , list )):
590601 raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
591602 elif prompt_2 is not None and (not isinstance (prompt_2 , str ) and not isinstance (prompt_2 , list )):
@@ -602,8 +613,8 @@ def check_inputs(
602613 )
603614 elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None :
604615 raise ValueError (
605- f"Cannot forward both `negative_prompt_2`: { negative_prompt_2 } and `negative_prompt_embeds `:"
606- f" { negative_prompt_embeds } . Please make sure to only forward one of the two."
616+ f"Cannot forward both `negative_prompt_2`: { negative_prompt_2 } and `negative_pooled_prompt_embeds `:"
617+ f" { negative_pooled_prompt_embeds } . Please make sure to only forward one of the two."
607618 )
608619 elif negative_prompt_3 is not None and negative_t5_prompt_embeds is not None :
609620 raise ValueError (
@@ -638,8 +649,6 @@ def check_inputs(
638649 f" { negative_llama3_prompt_embeds .shape } ."
639650 )
640651
641-
642-
643652 def prepare_latents (
644653 self ,
645654 batch_size ,
@@ -755,10 +764,8 @@ def __call__(
755764 batch_size = 1
756765 elif prompt is not None and isinstance (prompt , list ):
757766 batch_size = len (prompt )
758- elif prompt_embeds is not None :
759- batch_size = prompt_embeds [0 ].shape [0 ] if isinstance (prompt_embeds , list ) else prompt_embeds .shape [0 ]
760- else :
761- batch_size = 1
767+ elif pooled_prompt_embeds is not None :
768+ batch_size = pooled_prompt_embeds .shape [0 ]
762769
763770 device = self ._execution_device
764771
0 commit comments