Skip to content

Commit 4f00bae

Browse files
committed
update
1 parent a967e66 commit 4f00bae

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,8 @@ def encode_prompt(
299299
max_sequence_length=max_sequence_length,
300300
device=device,
301301
)
302+
303+
negative_text_ids = None
302304
if do_classifier_free_guidance and negative_prompt_embeds is None:
303305
negative_prompt = negative_prompt or ""
304306
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
@@ -321,6 +323,7 @@ def encode_prompt(
321323
max_sequence_length=max_sequence_length,
322324
device=device,
323325
)
326+
negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
324327

325328
if self.text_encoder is not None:
326329
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
@@ -329,7 +332,6 @@ def encode_prompt(
329332

330333
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
331334
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
332-
negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
333335

334336
return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
335337

0 commit comments

Comments (0)