
Commit 4046deb

add deepseekv3 rotary config
1 parent 01e0440 commit 4046deb

4 files changed: +325 additions, -13 deletions

lightllm/models/deepseek2/triton_kernel/rotary_emb.py

Lines changed: 19 additions & 12 deletions
@@ -42,18 +42,18 @@ def _rotary_kernel(
         sin = tl.load(Sin + off_dimcos_sin)

         for q_head_index in tl.static_range(0, HEAD_Q, step=1):
-            off_q0 = (seq_index * stride_qbs + q_head_index * stride_qh + dim_range0 * stride_qd)
-            off_q1 = (seq_index * stride_qbs + q_head_index * stride_qh + dim_range1 * stride_qd)
+            off_q0 = seq_index * stride_qbs + q_head_index * stride_qh + dim_range0 * stride_qd
+            off_q1 = seq_index * stride_qbs + q_head_index * stride_qh + dim_range1 * stride_qd
             q0 = tl.load(Q + off_q0)
             q1 = tl.load(Q + off_q1)
             out_q0 = q0 * cos - q1 * sin
             out_q1 = q0 * sin + q1 * cos
             tl.store(Q + off_q0, out_q0)
             tl.store(Q + off_q1, out_q1)
-
+
         for k_head_index in tl.static_range(0, HEAD_K, step=1):
-            off_k0 = (seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd)
-            off_k1 = (seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd)
+            off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd
+            off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd

             k0 = tl.load(K + off_k0)
             k1 = tl.load(K + off_k1)
@@ -67,21 +67,28 @@ def _rotary_kernel(


 @torch.no_grad()
-def rotary_emb_fwd(q, k, cos, sin):
+def rotary_emb_fwd(q, k, cos, sin, **run_config):
     total_len = q.shape[0]
     head_num_q, head_num_k = q.shape[1], k.shape[1]
     head_dim = q.shape[2]
     assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
     assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}"
     assert triton.next_power_of_2(head_dim) == head_dim

-    if total_len <= 512:
-        BLOCK_SEQ = 1
-    else:
-        BLOCK_SEQ = 16
+    from .rotary_emb_config import DeepseekV3RotaryKernelConfig
+
+    if not run_config:
+        run_config = DeepseekV3RotaryKernelConfig.try_to_get_best_config(
+            M=total_len,
+            Q_HEAD_NUM=head_num_q,
+            K_HEAD_NUM=head_num_k,
+            HEAD_DIM=head_dim,
+            out_dtype=str(q.dtype),
+        )

-    num_warps = 1
-    num_stages = 3
+    BLOCK_SEQ = run_config["BLOCK_SEQ"]
+    num_warps = run_config["num_warps"]
+    num_stages = run_config["num_stages"]

     grid = (triton.cdiv(total_len, BLOCK_SEQ),)
     _rotary_kernel[grid](
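
For context, a minimal sketch of the new calling convention for rotary_emb_fwd: with no extra keyword arguments it falls back to the tuned-config lookup added in this commit, while explicit launch parameters can be passed through **run_config. The tensor shapes and override values below are illustrative assumptions, not values taken from this commit.

import torch
from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd

# Illustrative shapes: 1024 tokens, 128 q heads, 1 k head, head_dim 64.
q = torch.randn((1024, 128, 64), device="cuda", dtype=torch.bfloat16)
k = torch.randn((1024, 1, 64), device="cuda", dtype=torch.bfloat16)
cos = torch.randn((1024, 32), device="cuda", dtype=torch.bfloat16)
sin = torch.randn((1024, 32), device="cuda", dtype=torch.bfloat16)

# Calling rotary_emb_fwd(q, k, cos, sin) with no kwargs would consult
# DeepseekV3RotaryKernelConfig.try_to_get_best_config(...) for BLOCK_SEQ,
# num_warps and num_stages; an explicit override bypasses the lookup:
rotary_emb_fwd(q, k, cos, sin, BLOCK_SEQ=16, num_warps=1, num_stages=3)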
lightllm/models/deepseek2/triton_kernel/rotary_emb_config.py

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+import os
+from frozendict import frozendict
+from functools import lru_cache
+from lightllm.common.kernel_config import KernelConfigs
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class DeepseekV3RotaryKernelConfig(KernelConfigs):
+    kernel_name: str = "deepseek_v3_rotary_emb_kernel"
+
+    @classmethod
+    @lru_cache(maxsize=200)
+    def try_to_get_best_config(
+        cls,
+        M: int,
+        Q_HEAD_NUM: int,
+        K_HEAD_NUM: int,
+        HEAD_DIM: int,
+        dtype: str,
+    ) -> dict:
+        key_params = {
+            "M": M,
+            "Q_HEAD_NUM": Q_HEAD_NUM,
+            "K_HEAD_NUM": K_HEAD_NUM,
+            "HEAD_DIM": HEAD_DIM,
+            "dtype": str(dtype),
+        }
+        key_params = frozendict(key_params)
+
+        finded_config = cls.get_the_config(key_params)
+
+        if finded_config:
+            config = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))]
+            return config
+        else:
+            if M <= 256:
+                config = {"BLOCK_SEQ": 1, "NUM_STAGE": 1, "num_warps": 1, "num_stages": 1}
+            else:
+                config = {"BLOCK_SEQ": 16, "NUM_STAGE": 1, "num_warps": 1, "num_stages": 1}
+            return config
+
+    @classmethod
+    def save_config(
+        cls,
+        M: int,
+        Q_HEAD_NUM: int,
+        K_HEAD_NUM: int,
+        HEAD_DIM: int,
+        dtype: str,
+        config_json: dict,
+    ):
+        key_params = {
+            "M": M,
+            "Q_HEAD_NUM": Q_HEAD_NUM,
+            "K_HEAD_NUM": K_HEAD_NUM,
+            "HEAD_DIM": HEAD_DIM,
+            "dtype": str(dtype),
+        }
+        key_params = frozendict(key_params)
+
+        return cls.store_config(key_params, config_json)
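
A hedged usage sketch of the config class: lookups are keyed on head counts, head dim and dtype, and within a stored table the entry whose M key is numerically closest to the requested token count wins; when nothing has been stored yet, the M <= 256 fallback applies. The table layout in the comment is inferred from try_to_get_best_config/save_config, not read from a shipped config file.

from lightllm.models.deepseek2.triton_kernel.rotary_emb_config import DeepseekV3RotaryKernelConfig

# If a tuned table such as {"128": {...}, "1024": {...}} has been stored for this
# key, a query with M=900 picks the "1024" entry (smallest |int(key) - M|);
# otherwise the built-in fallback (BLOCK_SEQ 1 or 16 depending on M) is returned.
cfg = DeepseekV3RotaryKernelConfig.try_to_get_best_config(
    M=900,
    Q_HEAD_NUM=128,
    K_HEAD_NUM=1,
    HEAD_DIM=64,
    dtype="torch.bfloat16",
)
print(cfg["BLOCK_SEQ"], cfg["num_warps"], cfg["num_stages"])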

lightllm/utils/tuning_utils.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def mp_tuning(func, args: Dict[str, Any]):
             best_cost_time = _cost_time
             best_config = _config

-    logger.info(f"best config {best_config} best cost time {best_cost_time}")
+    logger.info(f"args: {args} best config {best_config} best cost time {best_cost_time}")
     return best_config

Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
+import os
+import torch
+import time
+import torch.multiprocessing as mp
+import itertools
+from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.deepseek2.triton_kernel.rotary_emb_config import DeepseekV3RotaryKernelConfig
+from lightllm.utils.watchdog_utils import Watchdog
+from typing import List
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+def set_seed():
+    import torch
+    import random
+    import numpy as np
+
+    seed = 42
+    torch.manual_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+    return
+
+
+@torch.no_grad()
+def test_kernel(
+    M: int,
+    Q_HEAD_NUM: int,
+    K_HEAD_NUM: int,
+    HEAD_DIM: int,
+    dtype: torch.dtype,
+    test_count: int,
+    **config,
+):
+    set_seed()
+    input_tuples = []
+
+    q = torch.randn((M, Q_HEAD_NUM, HEAD_DIM), device="cuda", dtype=dtype) / 10
+    k = torch.randn((M, K_HEAD_NUM, HEAD_DIM), device="cuda", dtype=dtype) / 10
+    cos = torch.randn((M, HEAD_DIM // 2), device="cuda", dtype=dtype)
+    sin = torch.randn((M, HEAD_DIM // 2), device="cuda", dtype=dtype)
+
+    for _ in range(test_count):
+        input_tuples.append((q.clone(), k.clone(), cos.clone(), sin.clone()))
+
+    # warm_up
+    rotary_emb_fwd(q=q, k=k, cos=cos, sin=sin, **config)
+
+    graph = torch.cuda.CUDAGraph()
+
+    with torch.cuda.graph(graph):
+        for index in range(test_count):
+            q, k, cos, sin = input_tuples[index]
+            rotary_emb_fwd(q=q, k=k, cos=cos, sin=sin, **config)
+
+    graph.replay()
+
+    torch.cuda.synchronize()
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+    graph.replay()
+    end_event.record()
+    end_event.synchronize()
+
+    cost_time = start_event.elapsed_time(end_event)
+
+    logger.info(str(config))
+    logger.info(f"bf16 {M} cost time: {cost_time} ms")
+    return cost_time
+
+
+def worker(
+    M: int,
+    Q_HEAD_NUM: int,
+    K_HEAD_NUM: int,
+    HEAD_DIM: int,
+    dtype: torch.dtype,
+    test_count: int,
+    test_configs,
+    queue,
+):
+    dog = Watchdog(timeout=10)
+    dog.start()
+    try:
+        for index in range(len(test_configs)):
+            cost_time = test_kernel(
+                M=M,
+                Q_HEAD_NUM=Q_HEAD_NUM,
+                K_HEAD_NUM=K_HEAD_NUM,
+                HEAD_DIM=HEAD_DIM,
+                dtype=dtype,
+                test_count=test_count,
+                **test_configs[index],
+            )
+            dog.heartbeat()
+            queue.put(cost_time)  # Put result in queue
+
+    except Exception as ex:
+        logger.error(str(ex))
+        logger.exception(str(ex))
+        import sys
+
+        sys.exit(-1)
+        pass
+
+
+def get_test_configs(split_id, split_count):
+    index = 0
+    result = itertools.product([1, 2, 4, 8, 16, 32], [1, 2, 4, 8], [1, 2, 3, 4, 5])
+    for BLOCK_SEQ, num_warps, num_stages in result:
+        t_config = {
+            "BLOCK_SEQ": BLOCK_SEQ,
+            "num_warps": num_warps,
+            "num_stages": num_stages,
+        }
+        if index % split_count == split_id:
+            yield t_config
+            index += 1
+        else:
+            index += 1
+
+
+def tuning_configs(
+    device_id: int,  # use for mult mp tunning
+    device_count: int,
+    M: int,
+    Q_HEAD_NUM: int,
+    K_HEAD_NUM: int,
+    HEAD_DIM: int,
+    dtype: torch.dtype,
+    test_count: int,
+):
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
+    best_config, best_cost_time = None, 10000000
+    queue = mp.Queue()
+    test_configs = []
+    for t_config in get_test_configs(device_id, device_count):
+        test_configs.append(t_config)
+        if len(test_configs) < 256:
+            continue
+
+        p = mp.Process(
+            target=worker,
+            args=(
+                M,
+                Q_HEAD_NUM,
+                K_HEAD_NUM,
+                HEAD_DIM,
+                dtype,
+                test_count,
+                test_configs,
+                queue,
+            ),
+        )
+        p.start()
+        p.join()
+        while len(test_configs) != 0:
+            try:
+                cost_time = queue.get_nowait()
+                logger.info(f"get {test_configs[0]} cost_time: {cost_time}")
+                if cost_time < best_cost_time:
+                    best_config = test_configs[0]
+                    best_cost_time = cost_time
+                    logger.info(f"cur best : {best_config} {best_cost_time}")
+                del test_configs[0:1]
+            except:
+                del test_configs[0:16]
+                logger.info(f"cur best : {best_config} {best_cost_time}")
+                break
+
+    while len(test_configs) != 0:
+        p = mp.Process(
+            target=worker,
+            args=(
+                M,
+                Q_HEAD_NUM,
+                K_HEAD_NUM,
+                HEAD_DIM,
+                dtype,
+                test_count,
+                test_configs,
+                queue,
+            ),
+        )
+        p.start()
+        p.join()
+
+        while len(test_configs) != 0:
+            try:
+                cost_time = queue.get_nowait()
+                logger.info(f"get {test_configs[0]} cost_time: {cost_time}")
+                if cost_time < best_cost_time:
+                    best_config = test_configs[0]
+                    best_cost_time = cost_time
+                    logger.info(f"cur best : {best_config} {best_cost_time}")
+                del test_configs[0:1]
+            except:
+                del test_configs[0:16]
+                logger.info(f"cur best : {best_config} {best_cost_time}")
+                break
+
+    logger.info(f"M {M} {best_config} best cost: {best_cost_time}")
+    return best_config, best_cost_time
+
+
+if __name__ == "__main__":
+    torch.multiprocessing.set_start_method("spawn")
+    from lightllm.utils.tuning_utils import mp_tuning
+
+    # for deepseekv3 600B
+    q_head_num = 128
+    k_head_num = 1
+    head_dim = 64
+    dtype = torch.bfloat16
+    for m in [1, 128, 256, 1024, 2048, 4096, 8192]:
+        json_dict = {}
+        ans = mp_tuning(
+            tuning_configs,
+            {
+                "M": m,
+                "Q_HEAD_NUM": q_head_num,
+                "K_HEAD_NUM": k_head_num,
+                "HEAD_DIM": head_dim,
+                "dtype": dtype,
+                "test_count": 20,
+            },
+        )
+        json_dict[m] = ans
+        DeepseekV3RotaryKernelConfig.save_config(
+            M=m,
+            Q_HEAD_NUM=q_head_num,
+            K_HEAD_NUM=k_head_num,
+            HEAD_DIM=head_dim,
+            dtype=str(dtype),
+            config_json=json_dict,
+        )
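
To make the search space of the new tuning script (whose file path is not captured in this view) concrete, a small self-contained sketch of the grid that get_test_configs enumerates and how it is split round-robin across tuning processes; the two-worker split below is an illustrative assumption.

import itertools

# Same grid as get_test_configs: 6 BLOCK_SEQ values x 4 num_warps x 5 num_stages = 120 configs.
grid = list(itertools.product([1, 2, 4, 8, 16, 32], [1, 2, 4, 8], [1, 2, 3, 4, 5]))
print(len(grid))  # 120

# With split_count=2, worker 0 benchmarks the even-indexed configs and worker 1
# the odd ones, mirroring the `index % split_count == split_id` filter.
worker0 = [cfg for i, cfg in enumerate(grid) if i % 2 == 0]
worker1 = [cfg for i, cfg in enumerate(grid) if i % 2 == 1]
print(len(worker0), len(worker1))  # 60 60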
