Skip to content

Commit f4fdb3a

Browse files
fix bug for ascend npu (#10429)
1 parent 7ab7c12 commit f4fdb3a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1248,7 +1248,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
12481248
sin_out = []
12491249
pos = ids.float()
12501250
is_mps = ids.device.type == "mps"
1251-
freqs_dtype = torch.float32 if is_mps else torch.float64
1251+
is_npu = ids.device.type == "npu"
1252+
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
12521253
for i in range(n_axes):
12531254
cos, sin = get_1d_rotary_pos_embed(
12541255
self.axes_dim[i],

0 commit comments

Comments
 (0)