Skip to content

Commit 3607297

Browse files
author
sangchengmeng
committed
fix triton_rotary_rope_emb
1 parent 58f8c1d commit 3607297

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,32 +21,40 @@ def rotary_kernel(
2121
HALF_D: tl.constexpr,
2222
BLOCK_D: tl.constexpr,
2323
):
24-
pid_h = tl.program_id(0).to(tl.int64)
25-
pid_l = tl.program_id(1).to(tl.int64)
24+
# program_id(0) : seq idx)
25+
# program_id(1) : head idx)
26+
# program_id(2) : 块化的 d 维
27+
pid_l = tl.program_id(0).to(tl.int64)
28+
pid_h = tl.program_id(1).to(tl.int64)
2629
pid_blk = tl.program_id(2).to(tl.int64)
2730

31+
# 当前 block 处理的 d 下标
2832
offs_d = tl.arange(0, BLOCK_D)
2933
d = pid_blk * BLOCK_D + offs_d
3034
mask = d < D
3135

36+
# base 是 (l, h, 0) 的起始地址
3237
base = pid_l * stride_l + pid_h * stride_h
3338

39+
# ---- load x / cos / sin ----
3440
in_ptr = inp_ptr + base + d * stride_d
35-
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
36-
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
3741

38-
x = tl.load(in_ptr, mask=mask)
39-
cos = tl.load(cos_ptr_, mask=mask)
40-
sin = tl.load(sin_ptr_, mask=mask)
42+
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d * stride_cos_d
43+
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d * stride_sin_d
4144

45+
x = tl.load(in_ptr, mask=mask, other=0.0)
46+
cos = tl.load(cos_ptr_, mask=mask, other=0.0)
47+
sin = tl.load(sin_ptr_, mask=mask, other=0.0)
48+
49+
# ---- rotary partner (偶/奇半维互换) ----
4250
partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D)
4351
partner_ptr = inp_ptr + base + partner_d * stride_d
44-
partner_val = tl.load(partner_ptr, mask=mask)
52+
partner_val = tl.load(partner_ptr, mask=mask, other=0.0)
4553
rotated = tl.where(d < HALF_D, -partner_val, partner_val)
4654

4755
y = x * cos + rotated * sin
4856

49-
out_ptr_ = out_ptr + base + d
57+
out_ptr_ = out_ptr + base + d * stride_d
5058
tl.store(out_ptr_, y, mask=mask)
5159

5260

@@ -57,17 +65,19 @@ def apply_rotary_pos_emb_triton(
5765
assert cos.is_contiguous() and sin.is_contiguous()
5866
if tensor.ndim != 3:
5967
raise RuntimeError("tensor shape should be [L, H, D]")
68+
6069
orig_dtype = tensor.dtype
61-
x = tensor.float()
70+
x = tensor.float() # [L, H, D]
6271

72+
# cos/sin: [L, D/2] -> [L, D]
6373
cos = cos.repeat(1, 2).view(cos.size(0), -1).contiguous().float()
6474
sin = sin.repeat(1, 2).view(sin.size(0), -1).contiguous().float()
6575

6676
L, H, D = x.shape
6777
HALF_D = D // 2
6878
y = torch.empty_like(x)
6979

70-
grid = (H, L, triton.cdiv(D, BLOCK_D))
80+
grid = (L, H, triton.cdiv(D, BLOCK_D))
7181

7282
rotary_kernel[grid](
7383
inp_ptr=x,

0 commit comments

Comments
 (0)