Skip to content

Commit 3fe01a9

Browse files
author
sangchengmeng
committed
fix triton_rotary_rope_emb
1 parent 3607297 commit 3fe01a9

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,40 +21,33 @@ def rotary_kernel(
2121
HALF_D: tl.constexpr,
2222
BLOCK_D: tl.constexpr,
2323
):
24-
# program_id(0) : seq idx)
25-
# program_id(1) : head idx)
26-
# program_id(2) : 块化的 d 维
2724
pid_l = tl.program_id(0).to(tl.int64)
2825
pid_h = tl.program_id(1).to(tl.int64)
2926
pid_blk = tl.program_id(2).to(tl.int64)
3027

31-
# 当前 block 处理的 d 下标
3228
offs_d = tl.arange(0, BLOCK_D)
3329
d = pid_blk * BLOCK_D + offs_d
3430
mask = d < D
3531

36-
# base 是 (l, h, 0) 的起始地址
3732
base = pid_l * stride_l + pid_h * stride_h
3833

39-
# ---- load x / cos / sin ----
4034
in_ptr = inp_ptr + base + d * stride_d
4135

42-
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d * stride_cos_d
43-
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d * stride_sin_d
36+
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
37+
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
4438

45-
x = tl.load(in_ptr, mask=mask, other=0.0)
46-
cos = tl.load(cos_ptr_, mask=mask, other=0.0)
47-
sin = tl.load(sin_ptr_, mask=mask, other=0.0)
39+
x = tl.load(in_ptr, mask=mask)
40+
cos = tl.load(cos_ptr_, mask=mask)
41+
sin = tl.load(sin_ptr_, mask=mask)
4842

49-
# ---- rotary partner (偶/奇半维互换) ----
5043
partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D)
5144
partner_ptr = inp_ptr + base + partner_d * stride_d
52-
partner_val = tl.load(partner_ptr, mask=mask, other=0.0)
45+
partner_val = tl.load(partner_ptr, mask=mask)
5346
rotated = tl.where(d < HALF_D, -partner_val, partner_val)
5447

5548
y = x * cos + rotated * sin
5649

57-
out_ptr_ = out_ptr + base + d * stride_d
50+
out_ptr_ = out_ptr + base + d
5851
tl.store(out_ptr_, y, mask=mask)
5952

6053

@@ -67,9 +60,8 @@ def apply_rotary_pos_emb_triton(
6760
raise RuntimeError("tensor shape should be [L, H, D]")
6861

6962
orig_dtype = tensor.dtype
70-
x = tensor.float() # [L, H, D]
63+
x = tensor.float()
7164

72-
# cos/sin: [L, D/2] -> [L, D]
7365
cos = cos.repeat(1, 2).view(cos.size(0), -1).contiguous().float()
7466
sin = sin.repeat(1, 2).view(sin.size(0), -1).contiguous().float()
7567

0 commit comments

Comments
 (0)