diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index eb5067c37700..0b946e18782c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -690,7 +690,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: n_axes = ids.shape[-1] cos_out = [] sin_out = [] - pos = ids.squeeze().float() + pos = ids.float() is_mps = ids.device.type == "mps" freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes):