Skip to content

Commit 883f5c8

Browse files
committed
update
1 parent 0b09231 commit 883f5c8

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,9 @@ def _get_positions(
373373
return positions
374374

375375
def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
376-
freqs = torch.einsum("nd,dhf->nhf", pos.to(freqs), freqs)
376+
# Always run ROPE freqs computation in FP32
377+
with torch.set_default_dtype(torch.float32):
378+
freqs = torch.einsum("nd,dhf->nhf", pos, freqs)
377379
freqs_cos = torch.cos(freqs)
378380
freqs_sin = torch.sin(freqs)
379381
return freqs_cos, freqs_sin

0 commit comments

Comments
 (0)