Commit ba9c185 (1 parent: 7854061)

Commit message: update

File tree: 2 files changed (+22 / -21 lines)

src/diffusers/models/transformers/transformer_mochi.py (3 additions & 3 deletions)

```diff
@@ -18,6 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.tokenization_utils_base import import_protobuf_decode_error
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
@@ -478,9 +479,8 @@ def _get_positions(
         return positions
 
     def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
-        with torch.autocast(freqs.device.type, enabled=False):
-            # Always run ROPE freqs computation in FP32
-            freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
+        # Always run ROPE freqs computation in FP32
+        freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
 
         freqs_cos = torch.cos(freqs)
         freqs_sin = torch.sin(freqs)
```
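For context on the second hunk: after this change, `_create_rope` relies on explicit FP32 casts of the einsum operands instead of wrapping the computation in a disabled-autocast block. Below is a minimal standalone sketch of that frequency computation; all sizes are illustrative assumptions, not Mochi's real configuration.

```python
# Minimal sketch of the RoPE frequency math in the new _create_rope body.
# All sizes here are illustrative assumptions, not Mochi's actual dimensions.
import torch

num_tokens, pos_dim = 16, 3            # assumed: tokens and position axes ("n", "d")
num_heads, freq_dim = 4, 8             # assumed: heads and per-head freq dim ("h", "f")

pos = torch.randn(num_tokens, pos_dim)
freqs = torch.randn(pos_dim, num_heads, freq_dim)

# Always run the RoPE freqs computation in FP32 via explicit casts, as in the hunk.
freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
freqs_cos = torch.cos(freqs)           # consumed downstream by the attention blocks
freqs_sin = torch.sin(freqs)
```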

src/diffusers/pipelines/mochi/pipeline_mochi.py (19 additions & 18 deletions)

```diff
@@ -594,24 +594,25 @@ def __call__(
         batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        # 3. Prepare text embeddings
-        (
-            prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_embeds,
-            negative_prompt_attention_mask,
-        ) = self.encode_prompt(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
-            num_videos_per_prompt=num_videos_per_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            prompt_attention_mask=prompt_attention_mask,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
-            max_sequence_length=max_sequence_length,
-            device=device,
-        )
+        with torch.autocast("cuda", torch.float32):
+            # 3. Prepare text embeddings
+            (
+                prompt_embeds,
+                prompt_attention_mask,
+                negative_prompt_embeds,
+                negative_prompt_attention_mask,
+            ) = self.encode_prompt(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                num_videos_per_prompt=num_videos_per_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                prompt_attention_mask=prompt_attention_mask,
+                negative_prompt_attention_mask=negative_prompt_attention_mask,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
```
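To illustrate the pattern this hunk introduces, here is a self-contained sketch of wrapping one pipeline stage in a `torch.autocast` context manager. The toy embedding/linear "encoder" and the CPU/bfloat16 target are assumptions made so the snippet runs anywhere; they are not the Mochi pipeline's modules. Note that PyTorch documents float16/bfloat16 as the supported CUDA autocast dtypes, so the float32 target used in the diff may behave differently across PyTorch versions.

```python
# Self-contained sketch of the torch.autocast context-manager pattern used above.
# The tiny "encoder" and the CPU/bfloat16 target are assumptions for portability,
# not the Mochi pipeline's actual text encoder or dtype choice.
import torch
import torch.nn as nn

vocab_size, hidden = 1000, 16                      # assumed toy sizes
tokens = torch.randint(0, vocab_size, (1, 8))      # stand-in for a tokenized prompt

encoder = nn.Sequential(nn.Embedding(vocab_size, hidden), nn.Linear(hidden, hidden))

with torch.autocast("cpu", dtype=torch.bfloat16):
    # Ops inside the block follow the autocast policy for the chosen device/dtype.
    prompt_embeds = encoder(tokens)

print(prompt_embeds.dtype)  # bfloat16: nn.Linear is autocast-eligible on CPU
```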
