Skip to content

Commit 8e75005

Browse files
Kai zhengKai zheng
authored andcommitted
get_1d_rotary_pos_embed support npu
1 parent c14057c commit 8e75005

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,7 @@ def get_1d_rotary_pos_embed(
11511151
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
11521152
/ linear_factor
11531153
) # [D/2]
1154-
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
1154+
freqs = torch.outer(pos, freqs).float() # type: ignore # [S, D/2]
11551155
if use_real and repeat_interleave_real:
11561156
# flux, hunyuan-dit, cogvideox
11571157
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]

0 commit comments

Comments
 (0)