Skip to content

Commit 3ffa711

Browse files
committed
update
1 parent 66a5f59 commit 3ffa711

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -601,24 +601,25 @@ def __call__(
601601
batch_size = prompt_embeds.shape[0]
602602

603603
device = self._execution_device
604-
# 3. Prepare text embeddings
605-
(
606-
prompt_embeds,
607-
prompt_attention_mask,
608-
negative_prompt_embeds,
609-
negative_prompt_attention_mask,
610-
) = self.encode_prompt(
611-
prompt=prompt,
612-
negative_prompt=negative_prompt,
613-
do_classifier_free_guidance=self.do_classifier_free_guidance,
614-
num_videos_per_prompt=num_videos_per_prompt,
615-
prompt_embeds=prompt_embeds,
616-
negative_prompt_embeds=negative_prompt_embeds,
617-
prompt_attention_mask=prompt_attention_mask,
618-
negative_prompt_attention_mask=negative_prompt_attention_mask,
619-
max_sequence_length=max_sequence_length,
620-
device=device,
621-
)
604+
with torch.autocast("cuda", torch.float32):
605+
# 3. Prepare text embeddings
606+
(
607+
prompt_embeds,
608+
prompt_attention_mask,
609+
negative_prompt_embeds,
610+
negative_prompt_attention_mask,
611+
) = self.encode_prompt(
612+
prompt=prompt,
613+
negative_prompt=negative_prompt,
614+
do_classifier_free_guidance=self.do_classifier_free_guidance,
615+
num_videos_per_prompt=num_videos_per_prompt,
616+
prompt_embeds=prompt_embeds,
617+
negative_prompt_embeds=negative_prompt_embeds,
618+
prompt_attention_mask=prompt_attention_mask,
619+
negative_prompt_attention_mask=negative_prompt_attention_mask,
620+
max_sequence_length=max_sequence_length,
621+
device=device,
622+
)
622623
# if self.do_classifier_free_guidance:
623624
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
624625
# prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

0 commit comments

Comments
 (0)