Skip to content

Commit bc7e97f

Browse files
author
sangchengmeng
committed
fix pos_emb
1 parent 58f8c1d commit bc7e97f

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,37 +17,43 @@ def rotary_kernel(
1717
stride_cos_d,
1818
stride_sin_l,
1919
stride_sin_d,
20+
max_L,
2021
D: tl.constexpr,
2122
HALF_D: tl.constexpr,
2223
BLOCK_D: tl.constexpr,
24+
BLOCK_L: tl.constexpr,
2325
):
2426
pid_h = tl.program_id(0).to(tl.int64)
25-
pid_l = tl.program_id(1).to(tl.int64)
26-
pid_blk = tl.program_id(2).to(tl.int64)
27+
pid_l_blk = tl.program_id(1).to(tl.int64)
28+
pid_blk_d = tl.program_id(2).to(tl.int64)
2729

2830
offs_d = tl.arange(0, BLOCK_D)
29-
d = pid_blk * BLOCK_D + offs_d
31+
d = pid_blk_d * BLOCK_D + offs_d
3032
mask = d < D
3133

32-
base = pid_l * stride_l + pid_h * stride_h
34+
seq_start = pid_l_blk * BLOCK_L
35+
seq_end = (pid_l_blk + 1) * BLOCK_L
36+
seq_end = tl.where(seq_end < max_L, seq_end, max_L)
3337

34-
in_ptr = inp_ptr + base + d * stride_d
35-
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
36-
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
38+
for seq in tl.range(seq_start, seq_end):
39+
base = seq * stride_l + pid_h * stride_h
3740

38-
x = tl.load(in_ptr, mask=mask)
39-
cos = tl.load(cos_ptr_, mask=mask)
40-
sin = tl.load(sin_ptr_, mask=mask)
41+
in_ptr = inp_ptr + base + d * stride_d
42+
cos_ptr_ = cos_ptr + seq * stride_cos_l + d * stride_cos_d
43+
sin_ptr_ = sin_ptr + seq * stride_sin_l + d * stride_sin_d
4144

42-
partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D)
43-
partner_ptr = inp_ptr + base + partner_d * stride_d
44-
partner_val = tl.load(partner_ptr, mask=mask)
45-
rotated = tl.where(d < HALF_D, -partner_val, partner_val)
45+
x = tl.load(in_ptr, mask=mask, other=0.0)
46+
cos = tl.load(cos_ptr_, mask=mask, other=0.0)
47+
sin = tl.load(sin_ptr_, mask=mask, other=0.0)
4648

47-
y = x * cos + rotated * sin
49+
partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D)
50+
partner_ptr = inp_ptr + base + partner_d * stride_d
51+
partner_val = tl.load(partner_ptr, mask=mask, other=0.0)
52+
rotated = tl.where(d < HALF_D, -partner_val, partner_val)
4853

49-
out_ptr_ = out_ptr + base + d
50-
tl.store(out_ptr_, y, mask=mask)
54+
y = x * cos + rotated * sin
55+
out_ptr_ = out_ptr + base + d * stride_d
56+
tl.store(out_ptr_, y, mask=mask)
5157

5258

5359
def apply_rotary_pos_emb_triton(
@@ -67,7 +73,8 @@ def apply_rotary_pos_emb_triton(
6773
HALF_D = D // 2
6874
y = torch.empty_like(x)
6975

70-
grid = (H, L, triton.cdiv(D, BLOCK_D))
76+
BLOCK_L = 16
77+
grid = (H, triton.cdiv(L, BLOCK_L), triton.cdiv(D, BLOCK_D))
7178

7279
rotary_kernel[grid](
7380
inp_ptr=x,
@@ -81,9 +88,11 @@ def apply_rotary_pos_emb_triton(
8188
stride_cos_d=cos.stride(1),
8289
stride_sin_l=sin.stride(0),
8390
stride_sin_d=sin.stride(1),
91+
max_L=L,
8492
D=D,
8593
HALF_D=HALF_D,
8694
BLOCK_D=BLOCK_D,
95+
BLOCK_L=BLOCK_L,
8796
)
8897

8998
return y.to(orig_dtype)

0 commit comments

Comments
 (0)