@@ -252,6 +252,7 @@ def encode_prompt(
252252 num_images_per_prompt : int = 1 ,
253253 prompt_embeds : Optional [torch .FloatTensor ] = None ,
254254 negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
255+ do_classifier_free_guidance : bool = True ,
255256 max_sequence_length : int = 512 ,
256257 lora_scale : Optional [float ] = None ,
257258 ):
@@ -298,10 +299,22 @@ def encode_prompt(
298299 max_sequence_length = max_sequence_length ,
299300 device = device ,
300301 )
301-
302- if negative_prompt_embeds is None :
302+ if do_classifier_free_guidance and negative_prompt_embeds is None :
303303 negative_prompt = negative_prompt or ""
304304 negative_prompt = batch_size * [negative_prompt ] if isinstance (negative_prompt , str ) else negative_prompt
305+
306+ if prompt is not None and type (prompt ) is not type (negative_prompt ):
307+ raise TypeError (
308+ f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
309+ f" { type (prompt )} ."
310+ )
311+ elif batch_size != len (negative_prompt ):
312+ raise ValueError (
313+ f"`negative_prompt`: { negative_prompt } has batch size { len (negative_prompt )} , but `prompt`:"
314+ f" { prompt } has batch size { batch_size } . Please make sure that passed `negative_prompt` matches"
315+ " the batch size of `prompt`."
316+ )
317+
305318 negative_prompt_embeds = self ._get_t5_prompt_embeds (
306319 prompt = negative_prompt ,
307320 num_images_per_prompt = num_images_per_prompt ,
@@ -693,6 +706,7 @@ def __call__(
693706 negative_prompt = negative_prompt ,
694707 prompt_embeds = prompt_embeds ,
695708 negative_prompt_embeds = negative_prompt_embeds ,
709+ do_classifier_free_guidance = self .do_classifier_free_guidance ,
696710 device = device ,
697711 num_images_per_prompt = num_images_per_prompt ,
698712 max_sequence_length = max_sequence_length ,
0 commit comments