Skip to content

Commit 59c9f5d

Browse files
committed
update
1 parent 883f5c8 commit 59c9f5d

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,10 @@ def _get_positions(
373373
return positions
374374

375375
def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
376-
# Always run ROPE freqs computation in FP32
377-
with torch.set_default_dtype(torch.float32):
378-
freqs = torch.einsum("nd,dhf->nhf", pos, freqs)
376+
with torch.autocast(freqs.device.type, enabled=False):
377+
# Always run ROPE freqs computation in FP32
378+
freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
379+
379380
freqs_cos = torch.cos(freqs)
380381
freqs_sin = torch.sin(freqs)
381382
return freqs_cos, freqs_sin

0 commit comments

Comments
 (0)