 import torch
 import triton
 import triton.language as tl
+from typing import Optional
+from lightllm.common.kernel_config import KernelConfigs
+from frozendict import frozendict
+from functools import lru_cache
+from typing import Dict
+
+
+class GQADiverseDecodeStage1KernelConfig(KernelConfigs):
+    kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage1:v1"
+
+    @classmethod
+    @lru_cache(maxsize=200)
+    def try_to_get_best_config(
+        cls,
+        batch_size: int,
+        avg_seq_len_in_batch: int,
+        gqa_group_size: int,
+        q_head_dim: int,
+        block_seq: int,
+        out_dtype: str,
+    ) -> dict:
+        key_params = {
+            "gqa_group_size": gqa_group_size,
+            "q_head_dim": q_head_dim,
+            "block_seq": block_seq,
+            "out_dtype": str(out_dtype),
+        }
+        key_params = frozendict(key_params)
+
+        finded_config = cls.get_the_config(key_params)
+
+        if finded_config:
+            # Tuned configs are keyed by an avg-seq-len bucket first and a
+            # batch-size bucket second; pick the closest bucket at each level.
+            batch_size_config: dict = finded_config[
+                min(
+                    finded_config.keys(),
+                    key=lambda x: abs(int(x) - avg_seq_len_in_batch),
+                )
+            ]
+            config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))]
+
+            return config
+        else:
+            config = {
+                "BLOCK_N": 16,
+                "num_warps": 2,
+                "num_stages": 2,
+            }
+            return config
+
+    @classmethod
+    def save_config(
+        cls,
+        gqa_group_size: int,
+        q_head_dim: int,
+        block_seq: int,
+        out_dtype: str,
+        config_json: Dict[int, Dict[int, Dict]],
+    ):
+        key_params = {
+            "gqa_group_size": gqa_group_size,
+            "q_head_dim": q_head_dim,
+            "block_seq": block_seq,
+            "out_dtype": str(out_dtype),
+        }
+        key_params = frozendict(key_params)
+
+        return cls.store_config(key_params, config_json)
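+
+
+# Illustrative note (values below are hypothetical, not taken from any tuned-config
+# store): save_config expects config_json nested as
+# {avg_seq_len_bucket: {batch_size_bucket: kernel_config}}, mirroring the two-level
+# nearest-bucket lookup in try_to_get_best_config, e.g.
+# {1024: {8: {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}}.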
 
 
 @triton.jit
-def _fwd_kernel_flash_decode_stage1(
+def _fwd_kernel_flash_decode_diverse_stage1(
     Q,
     stride_qbs,
     stride_qh,
@@ -160,6 +227,7 @@ def flash_decode_stage1(
     mid_out_logsumexp: torch.Tensor,
     block_seq: int,
     max_batch_group_size: int,
+    run_config: Optional[dict] = None,
 ):
     """
     This kernel is a GQA operator customized for diverse generation, where b_mark_shared_group is a tensor of shape (batch_size,),
@@ -169,9 +237,27 @@ def flash_decode_stage1(
     Each non-zero position in b_mark_shared_group indicates how many preceding requests it forms a shared-prefix group with. Requests
     belonging to the same shared-prefix group necessarily hold identical values in the corresponding b_shared_seq_len.
     """
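+    # Illustrative reading of the description above (hypothetical values): with
+    # b_mark_shared_group = [0, 0, 2, 0], request 2 forms a shared-prefix group with
+    # the 2 preceding requests (requests 0-2), while request 3 stands alone; every
+    # request in that group holds the same value in b_shared_seq_len.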
+    if not run_config:
+        # Estimate the average sequence length in the batch; max_len_in_batch is
+        # currently used as the estimate both inside and outside CUDA graph capture.
+        if torch.cuda.is_current_stream_capturing():
+            avg_seq_len_in_batch = max_len_in_batch
+        else:
+            avg_seq_len_in_batch = max_len_in_batch
+
+        run_config = GQADiverseDecodeStage1KernelConfig.try_to_get_best_config(
+            batch_size=int(q.shape[0]),
+            avg_seq_len_in_batch=avg_seq_len_in_batch,
+            gqa_group_size=int(q.shape[1] // k.shape[1]),
+            q_head_dim=int(q.shape[2]),
+            block_seq=block_seq,
+            out_dtype=q.dtype,
+        )
+
+    # Tuning parameters resolved from the config: BLOCK_N is the KV tile width per
+    # inner step (it must evenly divide BLOCK_SEQ); num_warps and num_stages are
+    # Triton launch hints.
+    BLOCK_N = run_config["BLOCK_N"]
+    num_warps = run_config["num_warps"]
+    num_stages = run_config["num_stages"]
+
     assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3
     BLOCK_SEQ = block_seq
-    BLOCK_N = 16
     assert BLOCK_SEQ % BLOCK_N == 0
     # shape constraints
     Lq, Lk = q.shape[-1], k.shape[-1]
@@ -189,7 +275,7 @@ def flash_decode_stage1(
     if BLOCK_HEAD * BLOCK_BATCH < 16:
         BLOCK_BATCH = 16 // BLOCK_HEAD
 
-    _fwd_kernel_flash_decode_stage1[grid](
+    _fwd_kernel_flash_decode_diverse_stage1[grid](
         Q=q,
         stride_qbs=q.stride(0),
         stride_qh=q.stride(1),
@@ -227,7 +313,7 @@ def flash_decode_stage1(
         BLOCK_N=BLOCK_N,
         BLOCK_BATCH=BLOCK_BATCH,
         KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE,
-        num_warps=2,
-        num_stages=2,
+        num_warps=num_warps,
+        num_stages=num_stages,
     )
     return