Skip to content

Commit 3bd5aab

Browse files
SangChengC and sangchengmeng
authored
[add] tri_mrope (#844)
Co-authored-by: sangchengmeng <sangchengmeng@sensetime.com>
1 parent f696584 commit 3bd5aab

File tree

4 files changed

+291
-21
lines changed

4 files changed

+291
-21
lines changed

lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,20 @@
22
import torch.functional as F
33
import torch.distributed as dist
44
import numpy as np
5-
6-
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
75
from functools import partial
86

9-
10-
def rotate_half(x):
11-
x1 = x[..., : x.shape[-1] // 2]
12-
x2 = x[..., x.shape[-1] // 2 :]
13-
return torch.cat((-x2, x1), dim=-1)
14-
15-
16-
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
17-
mrope_section = mrope_section * 2
18-
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
19-
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
20-
21-
q_embed = (q * cos) + (rotate_half(q) * sin)
22-
k_embed = (k * cos) + (rotate_half(k) * sin)
23-
24-
return q_embed, k_embed
7+
from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
8+
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
259

2610

2711
class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer):
2812
def __init__(self, layer_num, network_config, mode=[]):
2913
super().__init__(layer_num, network_config, mode)
3014
self.mrope_section = network_config["rope_scaling"]["mrope_section"]
15+
axis_map = []
16+
for i, n in enumerate(self.mrope_section * 2):
17+
axis_map += [i % 3] * n
18+
self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
3119

3220
def _get_qkv(self, input, cache_kv, infer_state, layer_weight):
3321
q = layer_weight.q_proj.mm(input)
@@ -36,10 +24,9 @@ def _get_qkv(self, input, cache_kv, infer_state, layer_weight):
3624
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
3725
seq_len, _ = q.shape
3826
q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
27+
self.axis_map = self.axis_map.to(q.device)
3928
k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
40-
new_q, new_k = apply_multimodal_rotary_pos_emb(
41-
q, k, infer_state.position_cos, infer_state.position_sin, self.mrope_section
42-
)
29+
new_q, new_k = mrope_triton(q, k, infer_state.position_cos, infer_state.position_sin, self.axis_map)
4330
new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1)
4431
cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2)
4532

lightllm/models/qwen2_vl/triton_kernel/__init__.py

Whitespace-only changes.
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
import time
2+
import torch
3+
import triton
4+
import triton.language as tl
5+
6+
7+
@triton.jit
8+
def mrope_kernel(
9+
Q_ptr,
10+
K_ptr,
11+
COS_ptr,
12+
SIN_ptr,
13+
AXIS_ptr,
14+
QO_ptr,
15+
KO_ptr,
16+
B: tl.int32,
17+
H_q: tl.int32,
18+
H_k: tl.int32,
19+
L: tl.int32,
20+
D: tl.int32,
21+
HALF: tl.constexpr,
22+
s_tok: tl.int32,
23+
s_ax: tl.int32,
24+
q_sb: tl.int32,
25+
q_sh: tl.int32,
26+
q_sl: tl.int32,
27+
q_sd: tl.int32,
28+
k_sb: tl.int32,
29+
k_sh: tl.int32,
30+
k_sl: tl.int32,
31+
k_sd: tl.int32,
32+
qo_sb: tl.int32,
33+
qo_sh: tl.int32,
34+
qo_sl: tl.int32,
35+
qo_sd: tl.int32,
36+
ko_sb: tl.int32,
37+
ko_sh: tl.int32,
38+
ko_sl: tl.int32,
39+
ko_sd: tl.int32,
40+
BLOCK_D: tl.constexpr,
41+
):
42+
43+
total_h = H_q + H_k
44+
pid_bh = tl.program_id(0)
45+
pid_l = tl.program_id(1)
46+
47+
b = pid_bh // total_h
48+
h_local = pid_bh - b * total_h
49+
50+
is_q = h_local < H_q
51+
h_q = h_local
52+
h_k = h_local - H_q
53+
54+
sb = tl.where(is_q, q_sb, k_sb)
55+
sh = tl.where(is_q, q_sh, k_sh)
56+
sl = tl.where(is_q, q_sl, k_sl)
57+
sd = tl.where(is_q, q_sd, k_sd)
58+
59+
osb = tl.where(is_q, qo_sb, ko_sb)
60+
osh = tl.where(is_q, qo_sh, ko_sh)
61+
osl = tl.where(is_q, qo_sl, ko_sl)
62+
osd = tl.where(is_q, qo_sd, ko_sd)
63+
64+
base_ptr = tl.where(is_q, Q_ptr, K_ptr)
65+
out_ptr = tl.where(is_q, QO_ptr, KO_ptr)
66+
h_index = tl.where(is_q, h_q, h_k)
67+
68+
base = b * sb + h_index * sh + pid_l * sl
69+
offs = tl.arange(0, BLOCK_D)
70+
mask = offs < D
71+
72+
idx = base + offs * sd
73+
vals = tl.load(base_ptr + idx, mask=mask, other=0.0)
74+
75+
rot_offs = tl.where(offs < HALF, (offs + HALF) * sd, (offs - HALF) * sd)
76+
rot_vals = tl.load(base_ptr + base + rot_offs, mask=mask, other=0.0)
77+
rot_vals = tl.where(offs < HALF, -rot_vals, rot_vals)
78+
79+
axis_id = tl.load(AXIS_ptr + offs, mask=mask, other=0) # 0,1,2
80+
cos_idx = pid_l * s_tok + axis_id * s_ax + offs
81+
c = tl.load(COS_ptr + cos_idx, mask=mask, other=0.0)
82+
s = tl.load(SIN_ptr + cos_idx, mask=mask, other=0.0)
83+
84+
out = vals * c + rot_vals * s
85+
86+
out_idx = b * osb + h_index * osh + pid_l * osl + offs * osd
87+
tl.store(out_ptr + out_idx, out, mask=mask)
88+
89+
90+
def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, axis_map: torch.Tensor):
    """Apply multimodal rotary position embedding (mrope) to q and k via Triton.

    Args:
        q: query tensor of shape (B, H_q, L, D); arbitrary strides supported.
        k: key tensor of shape (B, H_k, L, D).
        cos, sin: rotary tables with one dim of size L (token) and one of
            size 3 (rope axis t/h/w). The last dim (size D) must be
            contiguous, since the kernel indexes it with stride 1.
        axis_map: int32 tensor with D entries mapping each head-dim channel
            to its rope axis (0/1/2).

    Returns:
        (q_out, k_out): rotated tensors with the same shapes as q and k.

    Raises:
        ValueError: if D is odd or axis_map does not cover the head dim.
    """
    B, H_q, L, D = q.shape
    H_k = k.shape[1]
    # Fail loudly instead of producing silently wrong rotations.
    if D % 2 != 0:
        raise ValueError(f"head dim must be even for rotary embedding, got {D}")
    if axis_map.numel() != D:
        raise ValueError(f"axis_map must have {D} entries, got {axis_map.numel()}")
    HALF = D // 2

    q_sb, q_sh, q_sl, q_sd = map(int, q.stride())
    k_sb, k_sh, k_sl, k_sd = map(int, k.stride())

    q_out = torch.empty_like(q)
    k_out = torch.empty_like(k)
    qo_sb, qo_sh, qo_sl, qo_sd = map(int, q_out.stride())
    ko_sb, ko_sh, ko_sl, ko_sd = map(int, k_out.stride())

    # Locate the token and axis dims of cos/sin by their sizes.
    # NOTE(review): this is ambiguous when another dim coincidentally equals
    # L (e.g. L == 3 or D == L) — confirm caller shapes, e.g. (3, 1, L, D).
    token_dim = next(i for i, s in enumerate(cos.shape) if s == L)
    axis_dim = next(i for i, s in enumerate(cos.shape) if s == 3)

    s_token = int(cos.stride(token_dim))
    s_axis = int(cos.stride(axis_dim))

    # One program per (batch, head) pair (q heads then k heads) per token.
    grid = (B * (H_q + H_k), L)

    # Cover the whole head dim with one tile. The previous hard-coded
    # BLOCK_D=128 silently ignored channels beyond 128 for larger head dims
    # (the `offs < D` mask never reached them).
    BLOCK_D = triton.next_power_of_2(D)

    mrope_kernel[grid](
        q,
        k,
        cos,
        sin,
        axis_map,
        q_out,
        k_out,
        B,
        H_q,
        H_k,
        L,
        D,
        HALF,
        s_token,
        s_axis,
        q_sb,
        q_sh,
        q_sl,
        q_sd,
        k_sb,
        k_sh,
        k_sl,
        k_sd,
        qo_sb,
        qo_sh,
        qo_sl,
        qo_sd,
        ko_sb,
        ko_sh,
        ko_sl,
        ko_sd,
        BLOCK_D=BLOCK_D,
        num_warps=4,
        num_stages=3,
    )
    return q_out, k_out
149+
150+
151+
# ---------------- test ---------------- #
def test():
    """Compare the Triton mrope kernel against a reference PyTorch
    implementation for accuracy, then benchmark both (requires CUDA)."""

    # Reference PyTorch implementation
    def rotate_half(x: torch.Tensor):
        # (x1, x2) -> (-x2, x1) over the last dim.
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
        # Split cos/sin into per-axis sections (duplicated for both rotary
        # halves) and pick axis i % 3 for section i.
        chunks = mrope_section * 2
        cos_embed = torch.cat(
            [m[i % 3] for i, m in enumerate(cos.split(chunks, dim=-1))],
            dim=-1,
        ).unsqueeze(unsqueeze_dim)
        sin_embed = torch.cat(
            [m[i % 3] for i, m in enumerate(sin.split(chunks, dim=-1))],
            dim=-1,
        ).unsqueeze(unsqueeze_dim)

        q_out = q * cos_embed + rotate_half(q) * sin_embed
        k_out = k * cos_embed + rotate_half(k) * sin_embed
        return q_out, k_out

    B, H_q, H_k, L, D = 3, 28, 4, 16384, 128
    mrope_section = [16, 24, 24]
    torch.manual_seed(0)
    device = "cuda"

    # transpose/contiguous/transpose yields non-contiguous layouts, which
    # exercises the kernel's stride handling.
    q = torch.rand(B, H_q, L, D, dtype=torch.float32, device=device).transpose(1, 2).contiguous().transpose(1, 2)
    k = torch.rand(B, H_k, L, D, dtype=torch.float32, device=device).transpose(1, 2).contiguous().transpose(1, 2)
    cos = torch.rand(3, 1, L, D, dtype=torch.float32, device=device)
    sin = torch.rand(3, 1, L, D, dtype=torch.float32, device=device)

    # Accuracy comparison
    axis_map = []
    for i, n in enumerate(mrope_section * 2):
        axis_map += [i % 3] * n
    axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
    ref_q, ref_k = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1)

    torch.cuda.synchronize()
    out_q, out_k = mrope_triton(q, k, cos, sin, axis_map)
    torch.cuda.synchronize()

    err_q = (out_q - ref_q).abs().max().item()
    err_k = (out_k - ref_k).abs().max().item()
    print(f"abs‑max error q:{err_q:.6f}, k:{err_k:.6f}")

    assert err_q < 1e-2 and err_k < 1e-2

    # Speed comparison
    n_iter = 100
    e0 = torch.cuda.Event(enable_timing=True)
    e1 = torch.cuda.Event(enable_timing=True)

    e0.record()
    for _ in range(n_iter):
        _ = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1)
    e1.record()
    torch.cuda.synchronize()
    t_ref = e0.elapsed_time(e1) / n_iter

    e0.record()
    for _ in range(n_iter):
        _ = mrope_triton(q, k, cos, sin, axis_map)
    e1.record()
    torch.cuda.synchronize()
    t_tri = e0.elapsed_time(e1) / n_iter

    print(f"torch {t_ref:.2f} ms/iter")
    print(f"triton {t_tri:.2f} ms/iter")
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import torch
2+
import pytest
3+
4+
# Import the Triton kernel function under test. Adjust the import path as needed.
5+
from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
6+
7+
# Reference Python implementation for multimodal rotary positional embeddings
8+
9+
10+
def rotate_half(x):
    """Rotate the last dim: (x1, x2) -> (-x2, x1), where x1/x2 are its halves."""
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Reference multimodal RoPE, as used by Qwen2-VL.

    cos/sin carry one table per rope axis (t/h/w) in a leading dim of size 3;
    ``mrope_section`` gives the per-axis channel counts for half the head dim
    (duplicated internally to cover both rotary halves). Returns the rotated
    (q_embed, k_embed) pair.
    """
    sections = mrope_section * 2  # one copy per rotary half

    def pick_axes(table):
        # Section i takes rope axis i % 3 from the leading size-3 dim.
        chunks = table.split(sections, dim=-1)
        merged = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
        return merged.unsqueeze(unsqueeze_dim)

    cos_sel = pick_axes(cos)
    sin_sel = pick_axes(sin)

    q_embed = q * cos_sel + rotate_half(q) * sin_sel
    k_embed = k * cos_sel + rotate_half(k) * sin_sel
    return q_embed, k_embed
25+
26+
27+
@pytest.mark.parametrize(
    "B,H_q,H_k,L,D,mrope_section",
    [
        (1, 2, 1, 16, 8, [4]),
        (1, 4, 2, 32, 16, [8]),
        (2, 3, 2, 16, 8, [4]),
    ],
)
def test_mrope_triton_correctness(B, H_q, H_k, L, D, mrope_section):
    """
    Test that the Triton kernel matches the reference PyTorch implementation.

    B/H_q/H_k/L/D are batch, query heads, key heads, sequence length and head
    dim; mrope_section lists per-axis channel counts for half the head dim
    (with a single entry, every channel maps to rope axis 0). Requires CUDA.
    """
    # Build the per-channel rope-axis lookup (duplicated for both rotary
    # halves, hence the * 2).
    axis_map = []
    for i, n in enumerate(mrope_section * 2):
        axis_map += [i % 3] * n
    axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")

    torch.manual_seed(0)
    device = "cuda"

    q = torch.rand((B, H_q, L, D), dtype=torch.float32, device=device)
    k = torch.rand((B, H_k, L, D), dtype=torch.float32, device=device)
    cos = torch.rand((3, 1, L, D), dtype=torch.float32, device=device)
    sin = torch.rand((3, 1, L, D), dtype=torch.float32, device=device)

    ref_q, ref_k = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1)

    out_q, out_k = mrope_triton(q, k, cos, sin, axis_map)

    assert torch.allclose(out_q, ref_q, rtol=1e-3, atol=1e-3)
    assert torch.allclose(out_k, ref_k, rtol=1e-3, atol=1e-3)
58+
59+
60+
# Allow running this test module directly, outside a pytest invocation.
if __name__ == "__main__":
    pytest.main()

0 commit comments

Comments
 (0)