
Commit 0f7dedb
Author: wangzaijun (committed)

add autotune

1 parent 6f3eeef  commit 0f7dedb
11 files changed: +470 / -40 lines
@@ -0,0 +1 @@
+{"4096": {"8": {"BLOCK_N": 16, "num_warps": 16, "num_stages": 9}, "32": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 9}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 10}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 8, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 11}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 11}}}

@@ -0,0 +1 @@
+{"4096": {"8": {"BLOCK_N": 16, "num_warps": 16, "num_stages": 9}, "32": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 9}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 10}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 8, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 11}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 11}}}

@@ -0,0 +1 @@
+{"4096": {"8": {"BLOCK_N": 32, "num_warps": 16, "num_stages": 7}, "32": {"BLOCK_N": 16, "num_warps": 16, "num_stages": 7}, "128": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 7}, "256": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 11}}, "8192": {"8": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 11}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 10}, "128": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 9}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 11}}}

@@ -0,0 +1 @@
+{"4096": {"8": {"BLOCK_N": 32, "num_warps": 16, "num_stages": 7}, "32": {"BLOCK_N": 16, "num_warps": 16, "num_stages": 7}, "128": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 7}, "256": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 11}}, "8192": {"8": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 11}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 10}, "128": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 9}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 11}}}

@@ -0,0 +1 @@
+{"4096": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 7}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 16, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 11}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 10}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}}}

@@ -0,0 +1 @@
+{"4096": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 7}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 16, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 11}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 10}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}}}

@@ -0,0 +1 @@
+{"4096": {"8": {"BLOCK_N": 16, "num_warps": 8, "num_stages": 9}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 7}, "256": {"BLOCK_N": 16, "num_warps": 16, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 7}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 11}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}}

@@ -0,0 +1 @@
+{"4096": {"8": {"BLOCK_N": 16, "num_warps": 8, "num_stages": 9}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 7}, "256": {"BLOCK_N": 16, "num_warps": 16, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 7}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 11}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}}

lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py

Lines changed: 91 additions & 5 deletions
@@ -1,10 +1,77 @@
 import torch
 import triton
 import triton.language as tl
+from typing import Optional
+from lightllm.common.kernel_config import KernelConfigs
+from frozendict import frozendict
+from functools import lru_cache
+from typing import Dict
+
+
+class GQADiverseDecodeStage1KernelConfig(KernelConfigs):
+    kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage1:v1"
+
+    @classmethod
+    @lru_cache(maxsize=200)
+    def try_to_get_best_config(
+        cls,
+        batch_size: int,
+        avg_seq_len_in_batch: int,
+        gqa_group_size: int,
+        q_head_dim: int,
+        block_seq: int,
+        out_dtype: str,
+    ) -> dict:
+        key_params = {
+            "gqa_group_size": gqa_group_size,
+            "q_head_dim": q_head_dim,
+            "block_seq": block_seq,
+            "out_dtype": str(out_dtype),
+        }
+        key_params = frozendict(key_params)
+
+        finded_config = cls.get_the_config(key_params)
+
+        if finded_config:
+            batch_size_config: dict = finded_config[
+                min(
+                    finded_config.keys(),
+                    key=lambda x: abs(int(x) - avg_seq_len_in_batch),
+                )
+            ]
+            config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))]
+
+            return config
+        else:
+            config = {
+                "BLOCK_N": 16,
+                "num_warps": 2,
+                "num_stages": 2,
+            }
+            return config
+
+    @classmethod
+    def save_config(
+        cls,
+        gqa_group_size: int,
+        q_head_dim: int,
+        block_seq: int,
+        out_dtype: str,
+        config_json: Dict[int, Dict[int, Dict]],
+    ):
+        key_params = {
+            "gqa_group_size": gqa_group_size,
+            "q_head_dim": q_head_dim,
+            "block_seq": block_seq,
+            "out_dtype": str(out_dtype),
+        }
+        key_params = frozendict(key_params)
+
+        return cls.store_config(key_params, config_json)
 
 
 @triton.jit
-def _fwd_kernel_flash_decode_stage1(
+def _fwd_kernel_flash_decode_diverse_stage1(
     Q,
     stride_qbs,
     stride_qh,
@@ -160,6 +227,7 @@ def flash_decode_stage1(
     mid_out_logsumexp: torch.Tensor,
     block_seq: int,
     max_batch_group_size: int,
+    run_config: Optional[dict] = None,
 ):
     """
     This kernel is a GQA operator customized for diverse generation, where b_mark_shared_group is a tensor of shape (batch_size,);
@@ -169,9 +237,27 @@ def flash_decode_stage1(
     every non-zero entry in b_mark_shared_group indicates how many preceding requests form a shared-prefix group with it. Requests
     belonging to the same shared-prefix group necessarily have identical values in the corresponding b_shared_seq_len.
     """
+    if not run_config:
+        if torch.cuda.is_current_stream_capturing():
+            avg_seq_len_in_batch = max_len_in_batch
+        else:
+            avg_seq_len_in_batch = max_len_in_batch
+
+        run_config = GQADiverseDecodeStage1KernelConfig.try_to_get_best_config(
+            batch_size=int(q.shape[0]),
+            avg_seq_len_in_batch=avg_seq_len_in_batch,
+            gqa_group_size=int(q.shape[1] // k.shape[1]),
+            q_head_dim=int(q.shape[2]),
+            block_seq=block_seq,
+            out_dtype=q.dtype,
+        )
+
+    BLOCK_N = run_config["BLOCK_N"]
+    num_warps = run_config["num_warps"]
+    num_stages = run_config["num_stages"]
+
     assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3
     BLOCK_SEQ = block_seq
-    BLOCK_N = 16
     assert BLOCK_SEQ % BLOCK_N == 0
     # shape constraints
     Lq, Lk = q.shape[-1], k.shape[-1]
@@ -189,7 +275,7 @@ def flash_decode_stage1(
     if BLOCK_HEAD * BLOCK_BATCH < 16:
         BLOCK_BATCH = 16 // BLOCK_HEAD
 
-    _fwd_kernel_flash_decode_stage1[grid](
+    _fwd_kernel_flash_decode_diverse_stage1[grid](
         Q=q,
         stride_qbs=q.stride(0),
         stride_qh=q.stride(1),
@@ -227,7 +313,7 @@ def flash_decode_stage1(
         BLOCK_N=BLOCK_N,
         BLOCK_BATCH=BLOCK_BATCH,
         KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE,
-        num_warps=2,
-        num_stages=2,
+        num_warps=num_warps,
+        num_stages=num_stages,
     )
     return
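With this change, flash_decode_stage1 resolves BLOCK_N, num_warps, and num_stages from the stored configs unless a run_config dict is passed in explicitly. A minimal sketch of the lookup the new code path performs when run_config is omitted; the geometry, block_seq, and dtype values below are illustrative placeholders, not taken from this commit:

import torch
from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage1 import (
    GQADiverseDecodeStage1KernelConfig,
)

# Placeholder geometry: 8 requests, 32 query heads, 4 KV heads, head_dim 128.
batch_size, q_head_num, kv_head_num, q_head_dim = 8, 32, 4, 128

run_config = GQADiverseDecodeStage1KernelConfig.try_to_get_best_config(
    batch_size=batch_size,
    avg_seq_len_in_batch=4096,
    gqa_group_size=q_head_num // kv_head_num,
    q_head_dim=q_head_dim,
    block_seq=256,                    # placeholder block_seq
    out_dtype=str(torch.float16),
)
# Falls back to {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2} when no tuned entry matches.
print(run_config)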

test/kernel/llama_gqa_decode_vsm_tuning.py

Lines changed: 32 additions & 35 deletions
@@ -267,41 +267,38 @@ def tuning_configs(
     torch.multiprocessing.set_start_method("spawn")
 
     from lightllm.utils.tuning_utils import mp_tuning
-    from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig
-
     import collections
 
-    store_json_ans = collections.defaultdict(dict)
-
     def config_iter():
-        for q_head_num in [32]:
-            for q_head_dim in [64, 128]:
-                for group_size in [8, 16, 32]:
-                    for batch_size in [1, 8, 16, 32, 64, 128, 256]:
-                        for seq_len in [256, 512, 1024, 2048, 4096, 8192]:
-                            if batch_size * seq_len > 128 * 1024 * 4:
-                                continue
-                            yield q_head_num, q_head_dim, group_size, batch_size, seq_len
-
-    for q_head_num, q_head_dim, group_size, batch_size, seq_len in config_iter():
-
-        kv_head_num = q_head_num // group_size
-        ans = mp_tuning(
-            tuning_configs,
-            {
-                "q_shape": [batch_size, q_head_num, q_head_dim],
-                "kv_shape": [batch_size * seq_len, kv_head_num, q_head_dim],
-                "test_seq_len": seq_len,
-                "dtype": torch.half,
-                "test_count": 1,
-            },
-        )
-        store_json_ans[seq_len][batch_size] = ans
-
-        GQAVSMDecodeAttentionKernelConfig.save_config(
-            q_head_num=q_head_num,
-            q_head_dim=q_head_dim,
-            kv_head_num=kv_head_num,
-            out_dtype=str(torch.half),
-            config_json=store_json_ans,
-        )
+        for batch_size in [1, 8, 16, 32, 64, 128, 256]:
+            for seq_len in [256, 512, 1024, 2048, 4096, 8192]:
+                if batch_size * seq_len > 128 * 1024 * 4:
+                    continue
+                yield batch_size, seq_len
+
+    for q_head_num in [32]:
+        for q_head_dim in [64, 128]:
+            for group_size in [8, 16, 32]:
+                store_json_ans = collections.defaultdict(dict)
+                for batch_size, seq_len in config_iter():
+
+                    kv_head_num = q_head_num // group_size
+                    ans = mp_tuning(
+                        tuning_configs,
+                        {
+                            "q_shape": [batch_size, q_head_num, q_head_dim],
+                            "kv_shape": [batch_size * seq_len, kv_head_num, q_head_dim],
+                            "test_seq_len": seq_len,
+                            "dtype": torch.half,
+                            "test_count": 1,
+                        },
+                    )
+                    store_json_ans[seq_len][batch_size] = ans
+
+                    GQAVSMDecodeAttentionKernelConfig.save_config(
+                        q_head_num=q_head_num,
+                        q_head_dim=q_head_dim,
+                        kv_head_num=kv_head_num,
+                        out_dtype=str(torch.half),
+                        config_json=store_json_ans,
+                    )
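The restructured tuning loop resets store_json_ans for every (q_head_num, q_head_dim, group_size) combination, so each saved config holds results for a single head geometry, in the same {seq_len: {batch_size: kernel_params}} layout as the JSON files added above. A sketch of the dict handed to save_config for one such combination; the numbers are placeholders, not measured results:

# Placeholder contents of store_json_ans for a single (q_head_dim, group_size) pair.
store_json_ans = {
    4096: {
        8: {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2},
        32: {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3},
    },
    8192: {
        8: {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2},
    },
}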
