@@ -142,6 +142,34 @@ def __init__(
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
 
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        prompt_embeds=None,
+        pooled_prompt_embeds=None,
+    ):
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
         image = self.feature_extractor.preprocess(
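
Before the next hunk, it may help to see the guard above in isolation. The stand-alone sketch below mirrors the validation logic of the new `check_inputs` method so its error behavior can be exercised without loading a pipeline; the function name `_check_inputs` and the `object()` placeholder are illustrative and not part of the PR.

# Stand-alone sketch mirroring the `check_inputs` guard added above (illustrative only).
def _check_inputs(prompt, prompt_2, prompt_embeds=None, pooled_prompt_embeds=None):
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
    if prompt_2 is not None and prompt_embeds is not None:
        raise ValueError("Cannot forward both `prompt_2` and `prompt_embeds`.")
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError("`pooled_prompt_embeds` must be passed together with `prompt_embeds`.")

# A text prompt and precomputed embeddings are mutually exclusive.
embeds = object()  # placeholder for an embeddings tensor; the guard only checks for None
try:
    _check_inputs(prompt="a photo of a cat", prompt_2=None, prompt_embeds=embeds)
except ValueError as err:
    print(err)

# Precomputed embeddings must be accompanied by their pooled counterpart.
try:
    _check_inputs(prompt=None, prompt_2=None, prompt_embeds=embeds)
except ValueError as err:
    print(err)
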
@@ -367,6 +395,11 @@ def __call__(
             batch_size = len(image)
         else:
             batch_size = image.shape[0]
+        if prompt is not None and isinstance(prompt, str):
+            prompt = batch_size * [prompt]
+
+
+
         device = self._execution_device
 
         # 3. Prepare image embeddings
@@ -382,7 +415,7 @@ def __call__(
             pooled_prompt_embeds,
             _,
         ) = self.encode_prompt(
-            prompt=prompt * batch_size,
+            prompt=prompt,
             prompt_2=prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
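
The call-site change in the last hunk works together with the broadcast added in the previous one: when `prompt` is a plain string, the old `prompt=prompt * batch_size` relied on Python string repetition, which concatenates the text instead of producing one prompt per image in the batch. A minimal sketch of the difference, in plain Python and independent of the pipeline:

batch_size = 2
prompt = "a photo of a cat"

# Old call site: multiplying a string repeats and concatenates the characters,
# yielding a single mangled prompt rather than one prompt per image.
old = prompt * batch_size
print(old)  # 'a photo of a cata photo of a cat'

# New behavior: a single string is broadcast to a list with one entry per image
# and then passed through to `encode_prompt` unchanged.
if prompt is not None and isinstance(prompt, str):
    prompt = batch_size * [prompt]
print(prompt)  # ['a photo of a cat', 'a photo of a cat']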