@@ -431,26 +431,35 @@ def __call__(
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
-            batch_size = prompt_embeds.shape[0]
+            batch_size = len(prompt_embeds)

         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-        ) = self.encode_prompt(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            dtype=dtype,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale,
-        )
+
+        # If prompt_embeds is provided and prompt is None, skip encoding
+        if prompt_embeds is not None and prompt is None:
+            if self.do_classifier_free_guidance and negative_prompt_embeds is None:
+                raise ValueError(
+                    "When `prompt_embeds` is provided without `prompt`, "
+                    "`negative_prompt_embeds` must also be provided for classifier-free guidance."
+                )
+        else:
+            (
+                prompt_embeds,
+                negative_prompt_embeds,
+            ) = self.encode_prompt(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                dtype=dtype,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.in_channels
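
With this patch, a caller that already holds text-encoder outputs can bypass `encode_prompt` entirely by passing `prompt=None` together with `prompt_embeds` (and, when classifier-free guidance is active, `negative_prompt_embeds`). A minimal usage sketch follows; the pipeline class, the checkpoint id, and the exact subset of keyword arguments `encode_prompt` accepts when called directly are assumptions for illustration, not taken from this diff.

import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint id; substitute the pipeline this diff targets.
pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Precompute the embeddings once, e.g. with the pipeline's own encoder.
# Keyword names match the call site in the diff; whether the remaining
# parameters have defaults is an assumption.
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a photo of an astronaut riding a horse",
    negative_prompt="",
    do_classifier_free_guidance=True,
    device="cuda",
    num_images_per_prompt=1,
)

# Reuse the cached embeddings across calls. Because prompt is None and
# prompt_embeds is set, the patched __call__ skips re-encoding, and it
# raises a ValueError if CFG is on but negative_prompt_embeds is missing.
image = pipe(
    prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
).images[0]

Skipping re-encoding means the text encoders are not needed on subsequent calls, so they can be offloaded or deleted after the embeddings are cached, which is presumably the memory and latency motivation for the change.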