Skip to content

Commit d59de2d

Browse files
author
sangchengmeng
committed
[support] add triton_mrope stride support
1 parent 5175f83 commit d59de2d

File tree

3 files changed

+98
-49
lines changed

3 files changed

+98
-49
lines changed

lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,22 @@ class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer):
1212
def __init__(self, layer_num, network_config, mode=[]):
1313
super().__init__(layer_num, network_config, mode)
1414
self.mrope_section = network_config["rope_scaling"]["mrope_section"]
15+
axis_map = []
16+
for i, n in enumerate(self.mrope_section * 2):
17+
axis_map += [i % 3] * n
18+
self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
1519

1620
def _get_qkv(self, input, cache_kv, infer_state, layer_weight):
1721
q = layer_weight.q_proj.mm(input)
1822
cache_kv = layer_weight.kv_proj.mm(
1923
input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
2024
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
2125
seq_len, _ = q.shape
22-
q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2).contiguous()
23-
k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2).contiguous()
24-
new_q, new_k = mrope_triton(q, k, infer_state.position_cos, infer_state.position_sin, self.mrope_section)
25-
new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1).contiguous()
26+
q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
27+
self.axis_map = self.axis_map.to(q.device)
28+
k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
29+
new_q, new_k = mrope_triton(q, k, infer_state.position_cos, infer_state.position_sin, self.axis_map)
30+
new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1)
2631
cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2)
2732

2833
return new_q, cache_kv

lightllm/models/qwen2_vl/triton_kernel/__init__.py

Whitespace-only changes.

lightllm/models/qwen2_vl/triton_kernel/mrope.py

Lines changed: 89 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,84 +5,106 @@
55

66

77
@triton.jit
8-
def mrope_kernel_combined(
8+
def mrope_kernel(
99
Q_ptr,
1010
K_ptr,
1111
COS_ptr,
1212
SIN_ptr,
13-
AXIS_MAP_ptr,
14-
Q_out_ptr,
15-
K_out_ptr,
13+
AXIS_ptr,
14+
QO_ptr,
15+
KO_ptr,
1616
B: tl.int32,
1717
H_q: tl.int32,
1818
H_k: tl.int32,
1919
L: tl.int32,
2020
D: tl.int32,
21-
HALF: tl.int32,
21+
HALF: tl.constexpr,
22+
q_sb: tl.int32,
23+
q_sh: tl.int32,
24+
q_sl: tl.int32,
25+
q_sd: tl.int32,
26+
k_sb: tl.int32,
27+
k_sh: tl.int32,
28+
k_sl: tl.int32,
29+
k_sd: tl.int32,
30+
qo_sb: tl.int32,
31+
qo_sh: tl.int32,
32+
qo_sl: tl.int32,
33+
qo_sd: tl.int32,
34+
ko_sb: tl.int32,
35+
ko_sh: tl.int32,
36+
ko_sl: tl.int32,
37+
ko_sd: tl.int32,
2238
BLOCK_D: tl.constexpr,
2339
):
24-
total_h = H_q + H_k
2540

41+
total_h = H_q + H_k
2642
pid_bh = tl.program_id(0)
2743
pid_l = tl.program_id(1)
2844

2945
b = pid_bh // total_h
30-
head_local = pid_bh - b * total_h
46+
h_local = pid_bh - b * total_h
3147

32-
# decide whether this head comes from q or k
33-
is_q = head_local < H_q
34-
head_q = head_local
35-
head_k = head_local - H_q
48+
is_q = h_local < H_q
49+
h_q = h_local
50+
h_k = h_local - H_q
3651

37-
base_ptr = tl.where(is_q, Q_ptr, K_ptr)
38-
out_ptr = tl.where(is_q, Q_out_ptr, K_out_ptr)
39-
h_sub = tl.where(is_q, head_q, head_k)
40-
H_sub = tl.where(is_q, H_q, H_k)
52+
sb = tl.where(is_q, q_sb, k_sb)
53+
sh = tl.where(is_q, q_sh, k_sh)
54+
sl = tl.where(is_q, q_sl, k_sl)
55+
sd = tl.where(is_q, q_sd, k_sd)
56+
57+
osb = tl.where(is_q, qo_sb, ko_sb)
58+
osh = tl.where(is_q, qo_sh, ko_sh)
59+
osl = tl.where(is_q, qo_sl, ko_sl)
60+
osd = tl.where(is_q, qo_sd, ko_sd)
4161

42-
# base offset for (b, h_sub, pid_l)
43-
base = ((b * H_sub + h_sub) * L + pid_l) * D
62+
base_ptr = tl.where(is_q, Q_ptr, K_ptr)
63+
out_ptr = tl.where(is_q, QO_ptr, KO_ptr)
64+
h_index = tl.where(is_q, h_q, h_k)
4465

66+
base = b * sb + h_index * sh + pid_l * sl
4567
offs = tl.arange(0, BLOCK_D)
46-
idx = base + offs
4768
mask = offs < D
4869

70+
idx = base + offs * sd
4971
vals = tl.load(base_ptr + idx, mask=mask, other=0.0)
50-
axis_id = tl.load(AXIS_MAP_ptr + offs, mask=mask, other=0)
51-
axis_id_b = b * 3 + axis_id
5272

53-
seq_off = pid_l * D
54-
cos_idx = axis_id_b * (L * D) + seq_off + offs
73+
rot_offs = tl.where(offs < HALF, (offs + HALF) * sd, (offs - HALF) * sd)
74+
rot_vals = tl.load(base_ptr + base + rot_offs, mask=mask, other=0.0)
75+
rot_vals = tl.where(offs < HALF, -rot_vals, rot_vals)
76+
77+
axis_id = tl.load(AXIS_ptr + offs, mask=mask, other=0) # 0,1,2
78+
cos_idx = axis_id * (L * D) + pid_l * D + offs
5579
c = tl.load(COS_ptr + cos_idx, mask=mask, other=0.0)
5680
s = tl.load(SIN_ptr + cos_idx, mask=mask, other=0.0)
5781

58-
# rotate_half
59-
rot_idx = tl.where(offs < HALF, idx + HALF, idx - HALF)
60-
rot_vals = tl.load(base_ptr + rot_idx, mask=mask, other=0.0)
61-
sign = tl.where(offs < HALF, -1.0, 1.0)
62-
rot_vals *= sign
82+
out = vals * c + rot_vals * s
6383

64-
out_vals = vals * c + rot_vals * s
65-
tl.store(out_ptr + idx, out_vals, mask=mask)
84+
out_idx = b * osb + h_index * osh + pid_l * osl + offs * osd
85+
tl.store(out_ptr + out_idx, out, mask=mask)
6686

6787

68-
def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mrope_section):
88+
def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, axis_map: torch.Tensor):
89+
6990
B, H_q, L, D = q.shape
7091
H_k = k.shape[1]
92+
HALF = D // 2
7193

72-
# build axis_map 0/1/2 label per feature dim
73-
axis_map = []
74-
for i, n in enumerate(mrope_section * 2):
75-
axis_map += [i % 3] * n
76-
axis_map = torch.tensor(axis_map, dtype=torch.int32, device=q.device)
77-
78-
cos_flat = cos.transpose(0, 1).expand(B, 3, L, D).contiguous().reshape(-1)
79-
sin_flat = sin.transpose(0, 1).expand(B, 3, L, D).contiguous().reshape(-1)
94+
q_sb, q_sh, q_sl, q_sd = map(int, q.stride())
95+
k_sb, k_sh, k_sl, k_sd = map(int, k.stride())
8096

8197
q_out = torch.empty_like(q)
8298
k_out = torch.empty_like(k)
99+
qo_sb, qo_sh, qo_sl, qo_sd = map(int, q_out.stride())
100+
ko_sb, ko_sh, ko_sl, ko_sd = map(int, k_out.stride())
101+
102+
cos_flat = cos.transpose(0, 1).contiguous().reshape(-1)
103+
sin_flat = sin.transpose(0, 1).contiguous().reshape(-1)
83104

84105
grid = (B * (H_q + H_k), L)
85-
mrope_kernel_combined[grid](
106+
107+
mrope_kernel[grid](
86108
q,
87109
k,
88110
cos_flat,
@@ -95,8 +117,26 @@ def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch
95117
H_k,
96118
L,
97119
D,
98-
D // 2,
120+
HALF,
121+
q_sb,
122+
q_sh,
123+
q_sl,
124+
q_sd,
125+
k_sb,
126+
k_sh,
127+
k_sl,
128+
k_sd,
129+
qo_sb,
130+
qo_sh,
131+
qo_sl,
132+
qo_sd,
133+
ko_sb,
134+
ko_sh,
135+
ko_sl,
136+
ko_sd,
99137
BLOCK_D=128,
138+
num_warps=4,
139+
num_stages=3,
100140
)
101141
return q_out, k_out
102142

@@ -125,21 +165,25 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
125165
k_out = k * cos_embed + rotate_half(k) * sin_embed
126166
return q_out, k_out
127167

128-
B, H_q, H_k, L, D = 1, 28, 4, 16384, 128
168+
B, H_q, H_k, L, D = 3, 28, 4, 16384, 128
129169
mrope_section = [16, 24, 24]
130170
torch.manual_seed(0)
131171
device = "cuda"
132172

133-
q = torch.rand(B, H_q, L, D, dtype=torch.float32, device=device)
134-
k = torch.rand(B, H_k, L, D, dtype=torch.float32, device=device)
173+
q = torch.rand(B, H_q, L, D, dtype=torch.float32, device=device).transpose(1, 2).contiguous().transpose(1, 2)
174+
k = torch.rand(B, H_k, L, D, dtype=torch.float32, device=device).transpose(1, 2).contiguous().transpose(1, 2)
135175
cos = torch.rand(3, 1, L, D, dtype=torch.float32, device=device)
136176
sin = torch.rand(3, 1, L, D, dtype=torch.float32, device=device)
137177

138178
# 精度对比
179+
axis_map = []
180+
for i, n in enumerate(mrope_section * 2):
181+
axis_map += [i % 3] * n
182+
axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
139183
ref_q, ref_k = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1)
140184

141185
torch.cuda.synchronize()
142-
out_q, out_k = mrope_triton(q, k, cos, sin, mrope_section)
186+
out_q, out_k = mrope_triton(q, k, cos, sin, axis_map)
143187
torch.cuda.synchronize()
144188

145189
err_q = (out_q - ref_q).abs().max().item()
@@ -162,7 +206,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
162206

163207
e0.record()
164208
for _ in range(n_iter):
165-
_ = mrope_triton(q, k, cos, sin, mrope_section)
209+
_ = mrope_triton(q, k, cos, sin, axis_map)
166210
e1.record()
167211
torch.cuda.synchronize()
168212
t_tri = e0.elapsed_time(e1) / n_iter

0 commit comments

Comments
 (0)