Commit 6f3eeef

Author: wangzaijun
Commit message: fix kernel impl
1 parent 5525b77 · commit 6f3eeef

File tree

3 files changed: +117 -9 lines changed


lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def _fwd_kernel_flash_decode_stage1(
         ).to(tl.int64)
         off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
         k = tl.load(K + off_k, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0)
-        att_value = tl.dot(q, k)
+        att_value = tl.dot(q, k.to(q.dtype))
         att_value *= sm_scale
         att_value = tl.where(offs_n_new[None, :] < cur_batch_end_index, att_value, float("-inf"))
         v = tl.load(
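The only functional change in this kernel is casting k to q's dtype before the dot product, so both operands of tl.dot share one element type even when the K cache is stored in a different precision than q. A minimal PyTorch sketch of the same score computation (hypothetical shapes, not part of the commit):

import torch

def block_scores_ref(q, k, sm_scale):
    # q: (heads, head_dim) in the compute dtype; k: (head_dim, block_n),
    # possibly stored in another dtype.
    # Mirrors att_value = tl.dot(q, k.to(q.dtype)) from the kernel above.
    return (q @ k.to(q.dtype)) * sm_scale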

lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py

Lines changed: 10 additions & 8 deletions
@@ -77,7 +77,7 @@ def _fwd_kernel_flash_decode_stage1(

     offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)
     Q_BATCH_HEAD_NUM: tl.constexpr = BLOCK_BATCH * BLOCK_HEAD
-    q = tl.load(Q + off_q, other=0.0).view(Q_BATCH_HEAD_NUM, BLOCK_HEADDIM)
+    q = tl.load(Q + off_q).reshape(Q_BATCH_HEAD_NUM, BLOCK_HEADDIM)

     sum_exp = tl.zeros([Q_BATCH_HEAD_NUM], dtype=tl.float32)
     max_logic = tl.zeros([Q_BATCH_HEAD_NUM], dtype=tl.float32) - float("inf")
@@ -88,17 +88,17 @@ def _fwd_kernel_flash_decode_stage1(
         n_mask = offs_n_new < cur_batch_end_index
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
-            mask=offs_n_new < cur_batch_end_index,
+            mask=n_mask,
             other=0,
         ).to(tl.int64)
         off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
         off_k_scale = off_k // KV_QUANT_GROUP_SIZE
         k = tl.load(K + off_k, mask=n_mask[None, :], other=0)
         k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0)
         k = k * k_scale
-        att_value = tl.dot(q, k)
+        att_value = tl.dot(q, k.to(q.dtype))
         att_value *= sm_scale
-        att_value = tl.where(offs_n_new[None, :] < cur_batch_end_index, att_value, -1000000000.0)
+        att_value = tl.where(n_mask[None, :], att_value, float("-inf"))
         v = tl.load(
             V + off_k.T,
             mask=n_mask[:, None],
@@ -117,7 +117,7 @@ def _fwd_kernel_flash_decode_stage1(
         exp_logic = tl.exp(att_value - new_max_logic[:, None])
         logic_scale = tl.exp(max_logic - new_max_logic)
         acc *= logic_scale[:, None]
-        acc += tl.dot(exp_logic.to(v.dtype), v)
+        acc += tl.dot(exp_logic.to(q.dtype), v.to(q.dtype))

         sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1)
         max_logic = new_max_logic
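This hunk is the online-softmax accumulation step: each block of scores is folded into the running accumulator without materializing the full score row, and the added casts keep both tl.dot operands in q's dtype. A hedged PyTorch sketch of the same update rule (scores for one block shaped (heads, block_n), values shaped (block_n, head_dim); names are illustrative):

import torch

def online_softmax_step(acc, sum_exp, max_logic, scores, values):
    # Merge one block of attention logits into the running accumulators.
    new_max = torch.maximum(max_logic, scores.max(dim=1).values)
    exp_logic = torch.exp(scores - new_max[:, None])
    scale = torch.exp(max_logic - new_max)            # rescale the old partial sums
    acc = acc * scale[:, None] + exp_logic @ values   # mirrors acc += tl.dot(exp_logic, v)
    sum_exp = sum_exp * scale + exp_logic.sum(dim=1)
    return acc, sum_exp, new_max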
@@ -135,11 +135,11 @@ def _fwd_kernel_flash_decode_stage1(
     )
     tl.store(
         Mid_O + off_mid_o,
-        (acc / sum_exp[:, None]).view(BLOCK_BATCH, BLOCK_HEAD, BLOCK_HEADDIM),
+        (acc / sum_exp[:, None]).reshape(BLOCK_BATCH, BLOCK_HEAD, BLOCK_HEADDIM),
     )
     tl.store(
         Mid_O_LogExpSum + off_mid_o_logexpsum,
-        (max_logic + tl.log(sum_exp)).view(BLOCK_BATCH, BLOCK_HEAD),
+        (max_logic + tl.log(sum_exp)).reshape(BLOCK_BATCH, BLOCK_HEAD),
     )
     return

@@ -169,6 +169,7 @@ def flash_decode_stage1(
    Each non-zero entry in b_mark_shared_group indicates how many preceding requests form a
    shared-prefix group with it; requests in the same shared-prefix group necessarily have identical values in b_shared_seq_len.
    """
+    assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3
     BLOCK_SEQ = block_seq
     BLOCK_N = 16
     assert BLOCK_SEQ % BLOCK_N == 0
@@ -182,6 +183,7 @@ def flash_decode_stage1(
     gqa_group_size = q.shape[1] // k.shape[1]
     assert triton.next_power_of_2(Lk) == Lk
     KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1]
+    assert KV_QUANT_GROUP_SIZE == 8
     BLOCK_HEAD = triton.next_power_of_2(gqa_group_size)
     BLOCK_BATCH = triton.next_power_of_2(max_batch_group_size)
     if BLOCK_HEAD * BLOCK_BATCH < 16:
@@ -198,7 +200,7 @@ def flash_decode_stage1(
         stride_kh=k.stride(1),
         stride_kd=k.stride(2),
         V=v,
-        V_scale=v,
+        V_scale=v_scale,
         stride_vbs=v.stride(0),
         stride_vh=v.stride(1),
         stride_vd=v.stride(2),
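The last hunk fixes the kernel launch: V_scale now receives the actual v_scale tensor rather than v itself. For context, a hedged sketch of how the per-group int8 KV dequantization is assumed to work here (a group of 8 consecutive head_dim channels shares one scale, matching the KV_QUANT_GROUP_SIZE == 8 assertion; the helper name is illustrative):

import torch

def dequant_kv_ref(kv_int8, kv_scale, group_size=8):
    # kv_int8:  (tokens, kv_heads, head_dim), int8
    # kv_scale: (tokens, kv_heads, head_dim // group_size), float
    # Broadcasting each scale over its group of channels is assumed to correspond
    # to the kernel's off_k_scale = off_k // KV_QUANT_GROUP_SIZE indexing.
    scale = kv_scale.repeat_interleave(group_size, dim=-1)
    return kv_int8.to(scale.dtype) * scale

Under this reading, k * k_scale inside the kernel recovers the floating-point K block, and the subsequent .to(q.dtype) keeps the tl.dot operands uniform.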
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+import pytest
+import torch
+from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1
+
+
+@pytest.fixture
+def setup_tensors():
+    batch_size = 4
+    num_heads = 4
+    kv_head_num = 1
+    seq_len = 256
+    head_dim = 128
+    max_len_in_batch = seq_len
+    block_seq = 256
+    max_batch_group_size = 4
+    quant_group_size = 8
+
+    test_dtype = torch.float32
+
+    kv_shape = (batch_size * seq_len, kv_head_num, head_dim)
+    kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size)
+
+    q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda")
+    k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda")
+    k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda")
+    v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda")
+    v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda")
+    Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len)
+    B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda")
+    b_shared_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
+    b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda")
+    mid_out = torch.zeros(
+        size=(batch_size, num_heads, (seq_len // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda"
+    )
+    mid_out_logsumexp = torch.zeros(
+        size=(batch_size, num_heads, (seq_len // block_seq) + 2), dtype=q.dtype, device="cuda"
+    )
+
+    return {
+        "q": q,
+        "k": k,
+        "k_scale": k_scale,
+        "v": v,
+        "v_scale": v_scale,
+        "Req_to_tokens": Req_to_tokens,
+        "B_req_idx": B_req_idx,
+        "b_shared_seq_len": b_shared_seq_len,
+        "b_mark_shared_group": b_mark_shared_group,
+        "max_len_in_batch": max_len_in_batch,
+        "mid_out": mid_out,
+        "mid_out_logsumexp": mid_out_logsumexp,
+        "block_seq": block_seq,
+        "max_batch_group_size": max_batch_group_size,
+    }
+
+
+def test_flash_decode_stage1_execution(setup_tensors):
+    flash_decode_stage1(
+        q=setup_tensors["q"],
+        k=setup_tensors["k"],
+        k_scale=setup_tensors["k_scale"],
+        v=setup_tensors["v"],
+        v_scale=setup_tensors["v_scale"],
+        Req_to_tokens=setup_tensors["Req_to_tokens"],
+        B_req_idx=setup_tensors["B_req_idx"],
+        b_shared_seq_len=setup_tensors["b_shared_seq_len"],
+        b_mark_shared_group=setup_tensors["b_mark_shared_group"],
+        max_len_in_batch=setup_tensors["max_len_in_batch"],
+        mid_out=setup_tensors["mid_out"],
+        mid_out_logsumexp=setup_tensors["mid_out_logsumexp"],
+        block_seq=setup_tensors["block_seq"],
+        max_batch_group_size=setup_tensors["max_batch_group_size"],
+    )
+
+    q = setup_tensors["q"]
+    k = setup_tensors["k"]
+    v = setup_tensors["v"]
+    true_mid_out = torch.zeros_like(setup_tensors["mid_out"])
+    true_mid_out_logsumexp = torch.zeros_like(setup_tensors["mid_out_logsumexp"])
+    new_q = q
+    new_k = k.to(q.dtype)
+    new_v = v.to(q.dtype)
+
+    from lightllm.models.llama.triton_kernel.gqa_flash_decoding_stage1 import (
+        flash_decode_stage1 as gqa_flash_decode_stage1,
+    )
+
+    gqa_flash_decode_stage1(
+        q=new_q,
+        k=new_k,
+        v=new_v,
+        Req_to_tokens=setup_tensors["Req_to_tokens"],
+        B_req_idx=setup_tensors["B_req_idx"],
+        B_Seqlen=setup_tensors["b_shared_seq_len"],
+        max_len_in_batch=setup_tensors["max_len_in_batch"],
+        mid_out=true_mid_out,
+        mid_out_logsumexp=true_mid_out_logsumexp,
+        block_seq=setup_tensors["block_seq"],
+    )
+    print(setup_tensors["mid_out"][0:4, 0, 0, 0], true_mid_out[0:4, 0, 0, 0])
+    assert torch.allclose(
+        setup_tensors["mid_out"][0:4, 0, 0, 0], true_mid_out[0:4, 0, 0, 0], atol=1e-2
+    ), "Mid output does not match expected values"
+    assert torch.allclose(
+        setup_tensors["mid_out_logsumexp"], true_mid_out_logsumexp, atol=1e-2
+    ), "LogSumExp output does not match expected values"
