diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1768c81ce039..c64b9587be77 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1248,7 +1248,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: sin_out = [] pos = ids.float() is_mps = ids.device.type == "mps" - freqs_dtype = torch.float32 if is_mps else torch.float64 + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i],