Commit 77f9d19

update
1 parent 53dbc37 commit 77f9d19

File tree

1 file changed: +5 -1 lines changed


src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 5 additions & 1 deletion
@@ -235,6 +235,9 @@ def _get_t5_prompt_embeds(
         text_input_ids = text_inputs.input_ids
         prompt_attention_mask = text_inputs.attention_mask
         prompt_attention_mask = prompt_attention_mask.bool().to(device)
+        if prompt == "" or prompt[-1] == "":
+            text_input_ids = torch.zeros_like(text_input_ids, device=device)
+            prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
 
         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
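
The added branch treats an empty prompt (either the bare string "" or a prompt list whose last entry is "") as a null conditioning: the token ids and the attention mask handed to the T5 encoder are zeroed out. The snippet below is a minimal standalone sketch of that check, with stand-in tensors in place of the real tokenizer output; the shapes and vocabulary size are illustrative and not taken from the diff.

import torch

device = "cpu"
prompt = [""]  # in the pipeline, a str prompt has typically been normalized to a list by this point

# Stand-ins for the tokenizer output: shape (batch, max_sequence_length); values are arbitrary.
text_input_ids = torch.randint(0, 32128, (1, 256), device=device)
prompt_attention_mask = torch.ones(1, 256, dtype=torch.bool, device=device)

# Same condition as the added lines: an empty prompt gets all-zero ids
# and an all-False attention mask.
if prompt == "" or prompt[-1] == "":
    text_input_ids = torch.zeros_like(text_input_ids, device=device)
    prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)

print(text_input_ids.sum().item(), prompt_attention_mask.any().item())  # prints: 0 False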

@@ -450,7 +453,8 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
+        latents = latents.to(dtype)
         return latents
 
     @property
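
The prepare_latents change draws the initial noise in float32 and only then casts it to the pipeline's working dtype, presumably so the initial latents are sampled at full precision even when the pipeline runs in fp16 or bf16 (the commit itself does not state the motivation). A minimal sketch of the new pattern, using a made-up latent shape:

import torch
from diffusers.utils.torch_utils import randn_tensor

shape = (1, 12, 28, 60, 106)  # placeholder latent shape, not taken from the diff
device = torch.device("cpu")
dtype = torch.bfloat16  # stand-in for the pipeline's working dtype

generator = torch.Generator(device).manual_seed(0)

# New behaviour: sample in float32, then cast to the working dtype.
latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
latents = latents.to(dtype)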
