Skip to content

Commit dcf320f

Browse files
yiyixuxusayakpaul
andauthored
small update on rotary embedding (#9354)
* update * fix --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 8ba90aa commit dcf320f

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/diffusers/models/embeddings.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,8 +608,11 @@ def get_1d_rotary_pos_embed(
608608
pos = torch.from_numpy(pos) # type: ignore # [S]
609609

610610
theta = theta * ntk_factor
611-
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
612-
freqs = freqs.to(pos.device)
611+
freqs = (
612+
1.0
613+
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
614+
/ linear_factor
615+
) # [D/2]
613616
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
614617
if use_real and repeat_interleave_real:
615618
# flux, hunyuan-dit, cogvideox

0 commit comments

Comments
 (0)