Skip to content

Commit 43a5279

Browse files
author
sangchengmeng
committed
[add] cos sin
1 parent d59de2d commit 43a5279

File tree

1 file changed

+12
-5
lines changed
  • lightllm/models/qwen2_vl/triton_kernel

1 file changed

+12
-5
lines changed

lightllm/models/qwen2_vl/triton_kernel/mrope.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@ def mrope_kernel(
1919
L: tl.int32,
2020
D: tl.int32,
2121
HALF: tl.constexpr,
22+
s_tok: tl.int32,
23+
s_ax: tl.int32,
2224
q_sb: tl.int32,
2325
q_sh: tl.int32,
2426
q_sl: tl.int32,
@@ -75,7 +77,7 @@ def mrope_kernel(
7577
rot_vals = tl.where(offs < HALF, -rot_vals, rot_vals)
7678

7779
axis_id = tl.load(AXIS_ptr + offs, mask=mask, other=0) # 0,1,2
78-
cos_idx = axis_id * (L * D) + pid_l * D + offs
80+
cos_idx = pid_l * s_tok + axis_id * s_ax + offs
7981
c = tl.load(COS_ptr + cos_idx, mask=mask, other=0.0)
8082
s = tl.load(SIN_ptr + cos_idx, mask=mask, other=0.0)
8183

@@ -99,16 +101,19 @@ def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch
99101
qo_sb, qo_sh, qo_sl, qo_sd = map(int, q_out.stride())
100102
ko_sb, ko_sh, ko_sl, ko_sd = map(int, k_out.stride())
101103

102-
cos_flat = cos.transpose(0, 1).contiguous().reshape(-1)
103-
sin_flat = sin.transpose(0, 1).contiguous().reshape(-1)
104+
token_dim = next(i for i, s in enumerate(cos.shape) if s == L)
105+
axis_dim = next(i for i, s in enumerate(cos.shape) if s == 3)
106+
107+
s_token = int(cos.stride(token_dim))
108+
s_axis = int(cos.stride(axis_dim))
104109

105110
grid = (B * (H_q + H_k), L)
106111

107112
mrope_kernel[grid](
108113
q,
109114
k,
110-
cos_flat,
111-
sin_flat,
115+
cos,
116+
sin,
112117
axis_map,
113118
q_out,
114119
k_out,
@@ -118,6 +123,8 @@ def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch
118123
L,
119124
D,
120125
HALF,
126+
s_token,
127+
s_axis,
121128
q_sb,
122129
q_sh,
123130
q_sl,

0 commit comments

Comments
 (0)