Skip to content

Commit bde26d3

Browse files
committed
fix for mps
1 parent c5412b9 commit bde26d3

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/diffusers/models/transformers/transformer_lumina2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,9 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300,
241241

242242
def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
    """Precompute rotary position-embedding frequency tables, one per axis.

    Args:
        axes_dim: Embedding dimension allotted to each axis.
        axes_lens: Maximum position length precomputed for each axis
            (zipped pairwise with ``axes_dim``).
        theta: RoPE base frequency. NOTE(review): this parameter is accepted
            but the call below uses ``self.theta`` — presumably identical;
            confirm against the caller before relying on the argument.

    Returns:
        A list of frequency tensors, one per ``(dim, length)`` pair, as
        produced by ``get_1d_rotary_pos_embed``.
    """
    # MPS (Apple Silicon) does not support float64, so fall back to
    # float32 there; elsewhere keep float64 for maximum RoPE precision.
    freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
    freqs_cis = []
    for d, e in zip(axes_dim, axes_lens):
        emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
        freqs_cis.append(emb)
    return freqs_cis
248249

0 commit comments

Comments
 (0)