From 7c508d455fa1d2048bab1f64f4e287ad0c5888a2 Mon Sep 17 00:00:00 2001
From: hlky
Date: Mon, 14 Apr 2025 14:53:26 +0100
Subject: [PATCH] Use float32 on mps or npu in transformer_hidream_image's rope

---
 .../models/transformers/transformer_hidream_image.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py
index 43949f797c3d..04622a7e04b2 100644
--- a/src/diffusers/models/transformers/transformer_hidream_image.py
+++ b/src/diffusers/models/transformers/transformer_hidream_image.py
@@ -95,7 +95,12 @@ def forward(self, latent):
 
 def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
     assert dim % 2 == 0, "The dimension must be even."
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    is_mps = pos.device.type == "mps"
+    is_npu = pos.device.type == "npu"
+
+    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
+    scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
     omega = 1.0 / (theta**scale)
 
     batch_size, seq_length = pos.shape
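
Note (not part of the patch): torch.float64 is not supported on Apple's MPS
backend, so building the RoPE frequency table in float64 fails on that device,
and npu is given the same fallback here; float32 trades a little precision for
portability. Below is a minimal standalone sketch of the patched dtype
selection. The name rope_freqs, the toy pos tensor, and the dim/theta values
are illustrative, not taken from diffusers:

    import torch

    def rope_freqs(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
        # Mirrors the patched dtype selection: mps and npu lack float64
        # support, so fall back to float32 there; keep float64 elsewhere.
        assert dim % 2 == 0, "The dimension must be even."
        dtype = torch.float32 if pos.device.type in ("mps", "npu") else torch.float64
        scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
        omega = 1.0 / (theta**scale)                # (dim // 2,) inverse frequencies
        return pos.to(dtype).unsqueeze(-1) * omega  # (..., seq, dim // 2) angles

    pos = torch.arange(16, dtype=torch.float32).unsqueeze(0)  # (1, 16) token positions
    print(rope_freqs(pos, dim=64, theta=10000).shape)         # torch.Size([1, 16, 32])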