Commit b7a3942

Author: sangchengmeng (committed)
Message: [add] tri_mrope
1 parent 1bc2342 commit b7a3942

File tree: 2 files changed, +179 −24 lines


lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py (7 additions, 24 deletions)
@@ -2,26 +2,10 @@
 import torch.functional as F
 import torch.distributed as dist
 import numpy as np
-
-from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from functools import partial

-
-def rotate_half(x):
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
-    mrope_section = mrope_section * 2
-    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
-    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
-
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-
-    return q_embed, k_embed
+from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer


 class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer):

@@ -35,12 +19,11 @@ def _get_qkv(self, input, cache_kv, infer_state, layer_weight):
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         seq_len, _ = q.shape
-        q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
-        k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
-        new_q, new_k = apply_multimodal_rotary_pos_emb(
-            q, k, infer_state.position_cos, infer_state.position_sin, self.mrope_section
-        )
-        new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1)
+        q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2).contiguous()
+        k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2).contiguous()
+        new_q, new_k = mrope_triton(q, k, infer_state.position_cos, infer_state.position_sin, self.mrope_section)
+        new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1).contiguous()
         cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2)

         return new_q, cache_kv
+
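The rewritten _get_qkv reshapes q and k into contiguous (1, heads, seq_len, head_dim) tensors before calling mrope_triton, then folds k back into the cache layout afterwards. Below is a minimal standalone sketch of that layout round trip; the sizes are illustrative placeholders, not values taken from the model.

import torch

# Illustrative sizes only; the real values come from the model config and the request batch.
seq_len, k_head_num, head_dim = 5, 4, 128

cache_k = torch.randn(seq_len, k_head_num, head_dim)                      # cache_kv[:, :tp_k_head_num_, :]
k = cache_k.view(1, seq_len, -1, head_dim).transpose(1, 2).contiguous()   # (1, H_k, L, D), the layout mrope_triton gets
new_k = k                                                                  # stand-in for the kernel output (same shape)
back = new_k.squeeze(0).permute(1, 0, 2)                                   # back to the (seq_len, H_k, D) cache layout
assert back.shape == cache_k.shape and torch.equal(back, cache_k)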
lightllm/models/qwen2_vl/triton_kernel/mrope.py (new file, 172 additions)

@@ -0,0 +1,172 @@
import time
import torch
import triton
import triton.language as tl


@triton.jit
def mrope_kernel_combined(
    Q_ptr,
    K_ptr,
    COS_ptr,
    SIN_ptr,
    AXIS_MAP_ptr,
    Q_out_ptr,
    K_out_ptr,
    B: tl.int32,
    H_q: tl.int32,
    H_k: tl.int32,
    L: tl.int32,
    D: tl.int32,
    HALF: tl.int32,
    BLOCK_D: tl.constexpr,
):
    total_h = H_q + H_k

    pid_bh = tl.program_id(0)
    pid_l = tl.program_id(1)

    b = pid_bh // total_h
    head_local = pid_bh - b * total_h

    # decide whether this head comes from q or k
    is_q = head_local < H_q
    head_q = head_local
    head_k = head_local - H_q

    base_ptr = tl.where(is_q, Q_ptr, K_ptr)
    out_ptr = tl.where(is_q, Q_out_ptr, K_out_ptr)
    h_sub = tl.where(is_q, head_q, head_k)
    H_sub = tl.where(is_q, H_q, H_k)

    # base offset for (b, h_sub, pid_l)
    base = ((b * H_sub + h_sub) * L + pid_l) * D

    offs = tl.arange(0, BLOCK_D)
    idx = base + offs
    mask = offs < D

    vals = tl.load(base_ptr + idx, mask=mask, other=0.0)
    axis_id = tl.load(AXIS_MAP_ptr + offs, mask=mask, other=0)
    axis_id_b = b * 3 + axis_id

    seq_off = pid_l * D
    cos_idx = axis_id_b * (L * D) + seq_off + offs
    c = tl.load(COS_ptr + cos_idx, mask=mask, other=0.0)
    s = tl.load(SIN_ptr + cos_idx, mask=mask, other=0.0)

    # rotate_half
    rot_idx = tl.where(offs < HALF, idx + HALF, idx - HALF)
    rot_vals = tl.load(base_ptr + rot_idx, mask=mask, other=0.0)
    sign = tl.where(offs < HALF, -1.0, 1.0)
    rot_vals *= sign

    out_vals = vals * c + rot_vals * s
    tl.store(out_ptr + idx, out_vals, mask=mask)


def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mrope_section):
    B, H_q, L, D = q.shape
    H_k = k.shape[1]

    # build axis_map: 0/1/2 label per feature dim
    axis_map = []
    for i, n in enumerate(mrope_section * 2):
        axis_map += [i % 3] * n
    axis_map = torch.tensor(axis_map, dtype=torch.int32, device=q.device)

    cos_flat = cos.transpose(0, 1).expand(B, 3, L, D).contiguous().reshape(-1)
    sin_flat = sin.transpose(0, 1).expand(B, 3, L, D).contiguous().reshape(-1)

    q_out = torch.empty_like(q)
    k_out = torch.empty_like(k)

    grid = (B * (H_q + H_k), L)
    mrope_kernel_combined[grid](
        q,
        k,
        cos_flat,
        sin_flat,
        axis_map,
        q_out,
        k_out,
        B,
        H_q,
        H_k,
        L,
        D,
        D // 2,
        BLOCK_D=128,
    )
    return q_out, k_out

# ---------------- test ---------------- #
def test():

    # PyTorch reference implementation
    def rotate_half(x: torch.Tensor):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
        chunks = mrope_section * 2
        cos_embed = torch.cat(
            [m[i % 3] for i, m in enumerate(cos.split(chunks, dim=-1))],
            dim=-1,
        ).unsqueeze(unsqueeze_dim)
        sin_embed = torch.cat(
            [m[i % 3] for i, m in enumerate(sin.split(chunks, dim=-1))],
            dim=-1,
        ).unsqueeze(unsqueeze_dim)

        q_out = q * cos_embed + rotate_half(q) * sin_embed
        k_out = k * cos_embed + rotate_half(k) * sin_embed
        return q_out, k_out

    B, H_q, H_k, L, D = 1, 28, 4, 16384, 128
    mrope_section = [16, 24, 24]
    torch.manual_seed(0)
    device = "cuda"

    q = torch.rand(B, H_q, L, D, dtype=torch.float32, device=device)
    k = torch.rand(B, H_k, L, D, dtype=torch.float32, device=device)
    cos = torch.rand(3, 1, L, D, dtype=torch.float32, device=device)
    sin = torch.rand(3, 1, L, D, dtype=torch.float32, device=device)

    # accuracy comparison
    ref_q, ref_k = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1)

    torch.cuda.synchronize()
    out_q, out_k = mrope_triton(q, k, cos, sin, mrope_section)
    torch.cuda.synchronize()

    err_q = (out_q - ref_q).abs().max().item()
    err_k = (out_k - ref_k).abs().max().item()
    print(f"abs-max error q:{err_q:.6f}, k:{err_k:.6f}")

    assert err_q < 1e-2 and err_k < 1e-2

    # speed comparison
    n_iter = 100
    e0 = torch.cuda.Event(enable_timing=True)
    e1 = torch.cuda.Event(enable_timing=True)

    e0.record()
    for _ in range(n_iter):
        _ = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1)
    e1.record()
    torch.cuda.synchronize()
    t_ref = e0.elapsed_time(e1) / n_iter

    e0.record()
    for _ in range(n_iter):
        _ = mrope_triton(q, k, cos, sin, mrope_section)
    e1.record()
    torch.cuda.synchronize()
    t_tri = e0.elapsed_time(e1) / n_iter

    print(f"torch  {t_ref:.2f} ms/iter")
    print(f"triton {t_tri:.2f} ms/iter")
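The mrope_section argument decides which of the three rotary frequency axes each head-dimension channel reads. A small worked example of the axis_map that mrope_triton builds for the test's mrope_section = [16, 24, 24] with head_dim = 128 (reading the three sections as Qwen2-VL's temporal/height/width groups is an assumption here, not something this file states):

# Worked example of the axis_map construction in mrope_triton for mrope_section = [16, 24, 24].
mrope_section = [16, 24, 24]

axis_map = []
for i, n in enumerate(mrope_section * 2):   # doubled so the second rotary half repeats the pattern
    axis_map += [i % 3] * n

assert len(axis_map) == 2 * sum(mrope_section) == 128
# Channels 0-15 read axis 0, 16-39 axis 1, 40-63 axis 2;
# the second half repeats: 64-79 axis 0, 80-103 axis 1, 104-127 axis 2.
assert axis_map[0] == 0 and axis_map[16] == 1 and axis_map[40] == 2
assert axis_map[64] == 0 and axis_map[80] == 1 and axis_map[104] == 2

A related sanity check, assuming B = 1 and cos shaped (3, 1, L, D) as in the test: the flattened cos buffer is addressed by the kernel's cos_idx exactly like cos[axis, 0, position, channel].

import torch

B, L, D = 1, 8, 128
cos = torch.rand(3, 1, L, D)
cos_flat = cos.transpose(0, 1).expand(B, 3, L, D).contiguous().reshape(-1)

axis, pos, ch = 2, 5, 100                             # arbitrary example indices
flat_idx = (0 * 3 + axis) * (L * D) + pos * D + ch    # b = 0, mirrors cos_idx in the kernel
assert torch.equal(cos_flat[flat_idx], cos[axis, 0, pos, ch])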
