Skip to content

Commit bf2e149

Browse files
committed
check inputs
1 parent ac266fb commit bf2e149

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,34 @@ def __init__(
142142
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
143143
)
144144

145+
def check_inputs(
146+
self,
147+
prompt,
148+
prompt_2,
149+
prompt_embeds=None,
150+
pooled_prompt_embeds=None,
151+
):
152+
153+
if prompt is not None and prompt_embeds is not None:
154+
raise ValueError(
155+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
156+
" only forward one of the two."
157+
)
158+
elif prompt_2 is not None and prompt_embeds is not None:
159+
raise ValueError(
160+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
161+
" only forward one of the two."
162+
)
163+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
164+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
165+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
166+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
167+
168+
if prompt_embeds is not None and pooled_prompt_embeds is None:
169+
raise ValueError(
170+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
171+
)
172+
145173
def encode_image(self, image, device, num_images_per_prompt):
146174
dtype = next(self.image_encoder.parameters()).dtype
147175
image = self.feature_extractor.preprocess(
@@ -367,6 +395,11 @@ def __call__(
367395
batch_size = len(image)
368396
else:
369397
batch_size = image.shape[0]
398+
if prompt is not None and isinstance(prompt, str):
399+
prompt = batch_size * [prompt]
400+
401+
402+
370403
device = self._execution_device
371404

372405
# 3. Prepare image embeddings
@@ -382,7 +415,7 @@ def __call__(
382415
pooled_prompt_embeds,
383416
_,
384417
) = self.encode_prompt(
385-
prompt=prompt * batch_size,
418+
prompt=prompt,
386419
prompt_2=prompt_2,
387420
prompt_embeds=prompt_embeds,
388421
pooled_prompt_embeds=pooled_prompt_embeds,

0 commit comments

Comments
 (0)