Skip to content

Commit a967e66

Browse files
committed
update
1 parent 2b559e9 commit a967e66

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ def encode_prompt(
252252
num_images_per_prompt: int = 1,
253253
prompt_embeds: Optional[torch.FloatTensor] = None,
254254
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
255+
do_classifier_free_guidance: bool = True,
255256
max_sequence_length: int = 512,
256257
lora_scale: Optional[float] = None,
257258
):
@@ -298,10 +299,22 @@ def encode_prompt(
298299
max_sequence_length=max_sequence_length,
299300
device=device,
300301
)
301-
302-
if negative_prompt_embeds is None:
302+
if do_classifier_free_guidance and negative_prompt_embeds is None:
303303
negative_prompt = negative_prompt or ""
304304
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
305+
306+
if prompt is not None and type(prompt) is not type(negative_prompt):
307+
raise TypeError(
308+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
309+
f" {type(prompt)}."
310+
)
311+
elif batch_size != len(negative_prompt):
312+
raise ValueError(
313+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
314+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
315+
" the batch size of `prompt`."
316+
)
317+
305318
negative_prompt_embeds = self._get_t5_prompt_embeds(
306319
prompt=negative_prompt,
307320
num_images_per_prompt=num_images_per_prompt,
@@ -693,6 +706,7 @@ def __call__(
693706
negative_prompt=negative_prompt,
694707
prompt_embeds=prompt_embeds,
695708
negative_prompt_embeds=negative_prompt_embeds,
709+
do_classifier_free_guidance=self.do_classifier_free_guidance,
696710
device=device,
697711
num_images_per_prompt=num_images_per_prompt,
698712
max_sequence_length=max_sequence_length,

0 commit comments

Comments
 (0)