
Commit 677f648

refactor: move deepseek tunning to new dir
1 parent 1b9f4f4 commit 677f648

File tree: 1 file changed (+333 −0 lines)

Lines changed: 333 additions & 0 deletions
@@ -0,0 +1,333 @@
import torch
import time
import os
import torch.multiprocessing as mp
from typing import List
from lightllm.utils.log_utils import init_logger
from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
from lightllm.utils.watchdog_utils import Watchdog

logger = init_logger(__name__)


def set_seed():
    import torch
    import random
    import numpy as np

    seed = 42
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    return


@torch.no_grad()
def test_decode_attentions(
    q_nope_shape: List[int],
    q_rope_shape: List[int],
    kv_nope_shape: List[int],
    kv_rope_shape: List[int],
    test_seq_len: int,
    dtype: torch.dtype,
    test_count: int = 20,
    **run_config,
):
    set_seed()
    # Build a minimal stand-in for the inference state object expected by the kernel.
    tmp_class = type("TestObj", (object,), {})
    infer_state = tmp_class()
    infer_state.batch_size = q_nope_shape[0]
    infer_state.max_len_in_batch = test_seq_len
    infer_state.req_manager = tmp_class()
    infer_state.req_manager.req_to_token_indexs = torch.zeros(
        (infer_state.batch_size, infer_state.max_len_in_batch), dtype=torch.int32, device="cuda"
    )
    infer_state.req_manager.req_to_token_indexs.view(-1)[:] = torch.arange(
        0, infer_state.batch_size * infer_state.max_len_in_batch, step=1, dtype=torch.int32
    ).cuda()
    infer_state.b_req_idx = torch.arange(0, infer_state.batch_size, step=1, dtype=torch.int32).cuda()
    infer_state.b_seq_len = torch.full((infer_state.batch_size,), fill_value=test_seq_len, dtype=torch.int32).cuda()
    infer_state.total_token_num_tensor = torch.sum(infer_state.b_seq_len)

    # Prepare one random input set per timed launch.
    input_tuples = []
    for _ in range(test_count):
        q_nope = torch.randn(q_nope_shape, device="cuda", dtype=dtype) / 10
        q_rope = torch.randn(q_rope_shape, device="cuda", dtype=dtype) / 10
        kv_buffer_shape = [
            (test_seq_len + 10) * infer_state.batch_size,
            kv_nope_shape[1],
            kv_nope_shape[2] + kv_rope_shape[2],
        ]
        kv_buffer = torch.randn(kv_buffer_shape, device="cuda", dtype=dtype) / 10

        kv_nope = kv_buffer[:, :, 0 : kv_nope_shape[2]]
        kv_rope = kv_buffer[:, :, kv_nope_shape[2] :]
        o_tensor = torch.empty_like(q_nope)
        input_tuples.append((q_nope, q_rope, kv_buffer, kv_nope, kv_rope, o_tensor))

    # Cache workspace tensors by shape so graph capture reuses the same allocations.
    tensor_dict = {}

    def inner_alloc_func(shape, dtype=torch.float32, device="cuda"):
        shape = tuple(shape)
        if shape not in tensor_dict:
            ans = torch.empty(shape, dtype=dtype, device=device)
            tensor_dict[shape] = ans
            return ans
        else:
            return tensor_dict[shape]

    # One eager launch before graph capture (compiles the Triton kernels).
    gqa_token_decode_attention_flash_decoding(
        q_nope,
        q_rope,
        kv_nope,
        kv_rope,
        infer_state,
        q_nope_shape[1],
        q_nope_shape[2],
        q_rope_shape[2],
        None,
        0.01,
        out=o_tensor,
        alloc_tensor_func=inner_alloc_func,
        **run_config,
    )

    # Capture test_count launches into a CUDA graph and time a single replay.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for index in range(test_count):
            q_nope, q_rope, kv_buffer, kv_nope, kv_rope, o_tensor = input_tuples[index]
            gqa_token_decode_attention_flash_decoding(
                q_nope,
                q_rope,
                kv_nope,
                kv_rope,
                infer_state,
                q_nope_shape[1],
                q_nope_shape[2],
                q_rope_shape[2],
                None,
                0.01,
                out=o_tensor,
                alloc_tensor_func=inner_alloc_func,
                **run_config,
            )

    graph.replay()

    torch.cuda.synchronize()
    start = time.time()
    graph.replay()
    torch.cuda.synchronize()

    cost_time = (time.time() - start) * 1000

    logger.info(f"bf16 {test_seq_len} cost time: {cost_time} ms")
    return cost_time


def worker(
    q_nope_shape: List[int],
    q_rope_shape: List[int],
    kv_nope_shape: List[int],
    kv_rope_shape: List[int],
    test_seq_len: int,
    dtype: torch.dtype,
    test_count: int,
    test_configs,
    queue,
):
    # The watchdog aborts this worker if a config stalls (no heartbeat within 10s).
    dog = Watchdog(timeout=10)
    dog.start()

    try:
        for index in range(len(test_configs)):
            tuning_config = test_configs[index]
            cost_time = test_decode_attentions(
                q_nope_shape=q_nope_shape,
                q_rope_shape=q_rope_shape,
                kv_nope_shape=kv_nope_shape,
                kv_rope_shape=kv_rope_shape,
                test_seq_len=test_seq_len,
                dtype=dtype,
                test_count=test_count,
                **tuning_config,
            )
            dog.heartbeat()
            queue.put(cost_time)  # Put result in queue
    except Exception as ex:
        logger.error(str(ex) + f" config {tuning_config}")
        import sys

        sys.exit(-1)


def get_test_configs(split_id, split_count):
    # Enumerate the Triton launch-parameter grid and yield only the slice assigned to
    # split_id (round-robin split across split_count tuning processes).
    index = 0
    for block_n in [16, 32]:
        for block_q_head in [16]:
            for stage1_num_warps in [2, 4, 8, 16]:
                for stage1_num_stages in [1, 2, 3, 4, 5, 6, 7, 8, 12, 15]:
                    for stage2_num_warps in [1, 2, 4]:
                        for stage2_num_stages in [1, 3]:
                            t_config = {
                                "BLOCK_N": block_n,
                                "BLOCK_Q_HEAD": block_q_head,
                                "stage1_num_warps": stage1_num_warps,
                                "stage1_num_stages": stage1_num_stages,
                                "stage2_num_warps": stage2_num_warps,
                                "stage2_num_stages": stage2_num_stages,
                            }
                            if index % split_count == split_id:
                                yield t_config
                            index += 1


def tuning_configs(
    device_id: int,  # which GPU this tuning process uses (for multi-process tuning)
    device_count: int,
    q_nope_shape: List[int],
    q_rope_shape: List[int],
    kv_nope_shape: List[int],
    kv_rope_shape: List[int],
    test_seq_len: int,
    dtype: torch.dtype,
    test_count: int,
):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
    best_config, best_cost_time = None, 10000000
    queue = mp.Queue()
    test_configs = []
    # Benchmark candidate configs in batches of 64, each batch in a fresh subprocess
    # so a crashing or hanging config does not kill the whole sweep.
    for t_config in get_test_configs(device_id, device_count):
        test_configs.append(t_config)
        if len(test_configs) < 64:
            continue

        p = mp.Process(
            target=worker,
            args=(
                q_nope_shape,
                q_rope_shape,
                kv_nope_shape,
                kv_rope_shape,
                test_seq_len,
                dtype,
                test_count,
                test_configs,
                queue,
            ),
        )
        p.start()
        p.join()

        # Drain results; an empty queue with configs left means the worker died on
        # test_configs[0], so drop it and retry the remaining configs later.
        while len(test_configs) != 0:
            try:
                cost_time = queue.get_nowait()
                logger.info(f"get {test_configs[0]} cost_time: {cost_time}")
                if cost_time < best_cost_time:
                    best_config = test_configs[0]
                    best_cost_time = cost_time
                    logger.info(f"cur best {best_config}, {best_cost_time}")
                del test_configs[0:1]
            except Exception:
                logger.info(f"cur best {best_config}, {best_cost_time}")
                del test_configs[0:1]
                break

    # Handle the leftover configs that never filled a full batch of 64.
    while len(test_configs) != 0:
        p = mp.Process(
            target=worker,
            args=(
                q_nope_shape,
                q_rope_shape,
                kv_nope_shape,
                kv_rope_shape,
                test_seq_len,
                dtype,
                test_count,
                test_configs,
                queue,
            ),
        )
        p.start()
        p.join()

        while len(test_configs) != 0:
            try:
                cost_time = queue.get_nowait()
                logger.info(f"get {test_configs[0]} cost_time: {cost_time}")
                if cost_time < best_cost_time:
                    best_config = test_configs[0]
                    best_cost_time = cost_time
                    logger.info(f"cur best {best_config}, {best_cost_time}")
                del test_configs[0:1]
            except Exception:
                logger.info(f"cur best {best_config}, {best_cost_time}")
                del test_configs[0:1]
                break

    logger.info(f"{best_config} best cost: {best_cost_time}")
    return best_config, best_cost_time


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")

    from lightllm.utils.tuning_utils import mp_tuning
    from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig

    q_head_num = 16
    q_head_dim = 512
    q_rope_dim = 64
    import collections

    # Sweep (batch_size, seq_len) points and persist the best config found for each.
    store_json_ans = collections.defaultdict(dict)
    for batch_size in [1, 8, 16, 32, 64, 128, 256]:
        for seq_len in [256, 512, 1024, 2048, 4096, 8192]:
            if batch_size * seq_len > 128 * 1024 * 4:
                continue

            ans = mp_tuning(
                tuning_configs,
                {
                    "q_nope_shape": [batch_size, q_head_num, q_head_dim],
                    "q_rope_shape": [batch_size, q_head_num, q_rope_dim],
                    "kv_nope_shape": [None, 1, q_head_dim],
                    "kv_rope_shape": [None, 1, q_rope_dim],
                    "test_seq_len": seq_len,
                    "dtype": torch.bfloat16,
                    "test_count": 40,
                },
            )
            store_json_ans[seq_len][batch_size] = ans

    MlaDecodeAttentionKernelConfig.save_config(
        q_head_num=q_head_num,
        q_head_dim=q_head_dim,
        q_rope_dim=q_rope_dim,
        out_dtype=str(torch.bfloat16),
        config_json=store_json_ans,
    )
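
For reference, a minimal sketch of running the tuner for a single shape on one GPU, calling tuning_configs directly instead of going through mp_tuning. This assumes the file is importable from its new location (the path is not shown in this diff); the shapes mirror the __main__ block above.

    # from <new tuning module> import tuning_configs  # path not shown in this diff
    import torch
    import torch.multiprocessing as mp

    if __name__ == "__main__":
        mp.set_start_method("spawn")  # worker subprocesses use CUDA
        best_config, best_cost_time = tuning_configs(
            device_id=0,
            device_count=1,  # single split, so the full config grid is tested
            q_nope_shape=[8, 16, 512],
            q_rope_shape=[8, 16, 64],
            kv_nope_shape=[None, 1, 512],
            kv_rope_shape=[None, 1, 64],
            test_seq_len=1024,
            dtype=torch.bfloat16,
            test_count=40,
        )
        print(best_config, best_cost_time)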
