Skip to content

Commit 7cf74df

Browse files
authored
fix mrope kernel error. (#849)
1 parent e59d067 commit 7cf74df

File tree

1 file changed

+3
-2
lines changed
  • lightllm/models/qwen2_vl/triton_kernel

1 file changed

+3
-2
lines changed

lightllm/models/qwen2_vl/triton_kernel/mrope.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,9 @@ def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch
101101
qo_sb, qo_sh, qo_sl, qo_sd = map(int, q_out.stride())
102102
ko_sb, ko_sh, ko_sl, ko_sd = map(int, k_out.stride())
103103

104-
token_dim = next(i for i, s in enumerate(cos.shape) if s == L)
105-
axis_dim = next(i for i, s in enumerate(cos.shape) if s == 3)
104+
assert len(cos.shape) == 4
105+
token_dim = 2
106+
axis_dim = 0
106107

107108
s_token = int(cos.stride(token_dim))
108109
s_axis = int(cos.stride(axis_dim))

0 commit comments

Comments
 (0)