Skip to content

Commit 27940f1

Browse files
author
wangzaijun
committed
fix
1 parent 364a166 commit 27940f1

File tree

1 file changed

+20
-25
lines changed

1 file changed

+20
-25
lines changed

lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,38 @@ def rotary_kernel(
1717
stride_cos_d,
1818
stride_sin_l,
1919
stride_sin_d,
20-
max_L,
20+
total_len,
2121
D: tl.constexpr,
2222
HALF_D: tl.constexpr,
2323
BLOCK_D: tl.constexpr,
24-
BLOCK_L: tl.constexpr,
2524
):
2625
pid_h = tl.program_id(0).to(tl.int64)
27-
pid_l_blk = tl.program_id(1).to(tl.int64)
28-
pid_blk_d = tl.program_id(2).to(tl.int64)
26+
pid_l_start = tl.program_id(1).to(tl.int64)
27+
pid_blk = tl.program_id(2).to(tl.int64)
2928

3029
offs_d = tl.arange(0, BLOCK_D)
31-
d = pid_blk_d * BLOCK_D + offs_d
30+
d = pid_blk * BLOCK_D + offs_d
3231
mask = d < D
32+
for pid_l in tl.range(pid_l_start, total_len, step=tl.num_programs(axis=1)):
3333

34-
seq_start = pid_l_blk * BLOCK_L
35-
seq_end = (pid_l_blk + 1) * BLOCK_L
36-
seq_end = tl.where(seq_end < max_L, seq_end, max_L)
37-
38-
for seq in tl.range(seq_start, seq_end):
39-
base = seq * stride_l + pid_h * stride_h
34+
base = pid_l * stride_l + pid_h * stride_h
4035

4136
in_ptr = inp_ptr + base + d * stride_d
42-
cos_ptr_ = cos_ptr + seq * stride_cos_l + d * stride_cos_d
43-
sin_ptr_ = sin_ptr + seq * stride_sin_l + d * stride_sin_d
37+
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
38+
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
4439

45-
x = tl.load(in_ptr, mask=mask, other=0.0)
46-
cos = tl.load(cos_ptr_, mask=mask, other=0.0)
47-
sin = tl.load(sin_ptr_, mask=mask, other=0.0)
40+
x = tl.load(in_ptr, mask=mask)
41+
cos = tl.load(cos_ptr_, mask=mask)
42+
sin = tl.load(sin_ptr_, mask=mask)
4843

4944
partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D)
5045
partner_ptr = inp_ptr + base + partner_d * stride_d
51-
partner_val = tl.load(partner_ptr, mask=mask, other=0.0)
46+
partner_val = tl.load(partner_ptr, mask=mask)
5247
rotated = tl.where(d < HALF_D, -partner_val, partner_val)
5348

5449
y = x * cos + rotated * sin
55-
out_ptr_ = out_ptr + base + d * stride_d
50+
51+
out_ptr_ = out_ptr + base + d
5652
tl.store(out_ptr_, y, mask=mask)
5753

5854

@@ -72,12 +68,12 @@ def apply_rotary_pos_emb_triton(
7268
L, H, D = x.shape
7369
HALF_D = D // 2
7470
y = torch.empty_like(x)
75-
76-
if L <= 4096:
77-
BLOCK_L = 16
71+
if L < 1024:
72+
grid_L = L
7873
else:
79-
BLOCK_L = 128
80-
grid = (H, triton.cdiv(L, BLOCK_L), triton.cdiv(D, BLOCK_D))
74+
grid_L = 1024
75+
76+
grid = (H, grid_L, triton.cdiv(D, BLOCK_D))
8177

8278
rotary_kernel[grid](
8379
inp_ptr=x,
@@ -91,11 +87,10 @@ def apply_rotary_pos_emb_triton(
9187
stride_cos_d=cos.stride(1),
9288
stride_sin_l=sin.stride(0),
9389
stride_sin_d=sin.stride(1),
94-
max_L=L,
90+
total_len=L,
9591
D=D,
9692
HALF_D=HALF_D,
9793
BLOCK_D=BLOCK_D,
98-
BLOCK_L=BLOCK_L,
9994
)
10095

10196
return y.to(orig_dtype)

0 commit comments

Comments
 (0)