
Commit 2d08577

Author: sufubao (committed)
[triton] autotune
1 parent d9e3ba2 commit 2d08577

File tree

3 files changed: +353 −90 lines changed

lightllm/common/autotuner.py

Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
from __future__ import annotations

from functools import wraps
import builtins
import os
import time
import inspect
from typing import Dict
from tqdm import tqdm

from triton.testing import do_bench, do_bench_cudagraph
from triton.runtime.jit import KernelInterface
from triton.runtime.errors import OutOfResources
import triton

from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


def closest_power_of_2(n):
    n = int(n)
    # For values <= 1, return 1 directly
    if n <= 1:
        return 1
    # Use bit operations to find the nearest power of 2
    lower = 1 << (n.bit_length() - 1)
    upper = lower << 1
    return lower if (n - lower) < (upper - n) else upper


def get_str(name_list, value_list):
    return ",".join([f"{name}={value}" for (name, value) in zip(name_list, value_list)])


class Autotuner(KernelInterface):
    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        restore_value,
        pre_hook=None,
        post_hook=None,
        prune_configs_by: Dict = None,
        warmup=25,
        rep=100,
        use_cuda_graph=False,
    ):
        if not configs:
            self.configs = [triton.Config({}, num_warps=4, num_stages=2, num_ctas=1)]
        else:
            self.configs = configs
        self.key_idx = [arg_names.index(k) for k in key]
        self.cache = {}
        self.arg_names = arg_names

        # Reset to zero or restore values
        self.reset_idx = []
        if reset_to_zero is not None:
            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
        self.restore_idx = []
        if restore_value is not None:
            self.restore_idx = [arg_names.index(k) for k in restore_value]

        # Hook to reset or restore for required tensors
        self.pre_hook = lambda args, reset_only=False: 0
        self.post_hook = lambda args, exception: 0
        if pre_hook:
            self.pre_hook = pre_hook
        elif len(self.reset_idx) > 0 or len(self.restore_idx) > 0:

            def _pre_hook(args, reset_only=False):
                for i in self.reset_idx:
                    args[i].zero_()
                if not reset_only:
                    self.restore_copies = [args[i].clone() for i in self.restore_idx]

            self.pre_hook = _pre_hook

        if post_hook:
            self.post_hook = post_hook
        elif len(self.restore_idx) > 0:

            def _post_hook(args, exception):
                for i, j in enumerate(self.restore_idx):
                    args[j].copy_(self.restore_copies[i])
                self.restore_copies = []

            self.post_hook = _post_hook

        self.perf_model = None
        self.configs_top_k = 1.0
        self.early_config_prune = None
        if prune_configs_by:
            self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
            self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
            self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune)

        self.fn = fn
        self.fn_name = f"{os.path.relpath(fn.__module__)}.{fn.__name__}"
        self.base_fn = fn
        while not inspect.isfunction(self.base_fn):
            self.base_fn = self.base_fn.fn
        self.num_warmups = warmup
        self.num_reps = rep
        import torch

        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()

    def _bench(self, *args, config, **meta):
        from triton.compiler.errors import CompileTimeAssertionFailure

        # Check for conflicts, i.e. meta-parameters provided both as kwargs and
        # by the autotuner. Unlike triton's stock Autotuner (which raises),
        # caller-supplied defaults are silently dropped so the tuned config wins.
        conflicts = meta.keys() & config.all_kwargs().keys()
        if conflicts:
            meta = {k: v for k, v in meta.items() if k not in conflicts}

        # Augment meta-parameters with tunable ones
        current = dict(meta, **config.all_kwargs())
        full_nargs = {**self.nargs, **current}

        def kernel_call():
            if config.pre_hook:
                config.pre_hook(full_nargs)
            self.pre_hook(args)
            try:
                self.fn.run(
                    *args,
                    **current,
                )
            except Exception as e:
                try:
                    self.post_hook(args, exception=e)
                finally:
                    # Re-raise the exception from `self.fn.run`
                    raise

            self.post_hook(args, exception=None)

        try:
            if self.use_cuda_graph:
                import torch

                with torch.cuda.stream(torch.cuda.Stream()):
                    bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median")
                return bench_res
            return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
        except (OutOfResources, CompileTimeAssertionFailure):
            return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")]

    def run(self, *args, **kwargs):
        # Autotuning is opt-in: with ENABLE_AUTOTUNE unset or "0" the kernel runs directly.
        if os.environ.get("ENABLE_AUTOTUNE", "0") == "0":
            return self.fn.run(*args, **kwargs)

        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = []
            _args_name = []
            for name in self.arg_names:
                if name in all_args:
                    _args.append(all_args[name])
                    _args_name.append(name)
            key_list = [_args[i] for i in self.key_idx]
            key = tuple(key_list)
            if key not in self.cache:
                name_list = [_args_name[i] for i in self.key_idx]
                used_cached_result = False
                bench_start = time.time()
                timings = {
                    config: self._bench(*args, config=config, **kwargs)
                    for config in tqdm(self.configs, desc=f"Tuning {self.fn_name}::{get_str(name_list, key_list)}")
                }
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = builtins.min(timings, key=timings.get)
                self.pre_hook(args, reset_only=True)
                self.configs_timings = timings
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.best_config = config

        # Drop caller-supplied kwargs that the chosen config overrides.
        conflicts = kwargs.keys() & config.all_kwargs().keys()
        kwargs = {k: v for k, v in kwargs.items() if k not in conflicts}

        if not used_cached_result:
            logger.debug(
                f"Triton autotuning for function {self.base_fn.__name__} finished after "
                f"{self.bench_time:.2f}s; best config selected: {self.best_config};"
            )
        if config.pre_hook is not None:
            config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})

        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )

        self.nargs = None
        return ret

    def prune_configs(self, kwargs):
        pruned_configs = self.configs
        if self.early_config_prune:
            pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
        if self.perf_model:
            top_k = self.configs_top_k
            if isinstance(top_k, float) and top_k <= 1.0:
                top_k = int(len(self.configs) * top_k)
            if len(pruned_configs) > top_k:
                est_timing = {
                    config: self.perf_model(
                        **self.nargs,
                        **kwargs,
                        **config.all_kwargs(),
                    )
                    for config in pruned_configs
                }
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
        return pruned_configs

    def warmup(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        ret = []
        for config in self.prune_configs(kwargs):
            ret.append(
                self.fn.warmup(
                    *args,
                    **kwargs,
                    **config.all_kwargs(),
                )
            )
        self.nargs = None
        return ret


def autotune(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=True,
):
    def autotuned(fn):
        return Autotuner(
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            restore_value,
            pre_hook=pre_hook,
            post_hook=post_hook,
            prune_configs_by=prune_configs_by,
            warmup=warmup,
            rep=rep,
            use_cuda_graph=use_cuda_graph,
        )

    return autotuned
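
For orientation, a hypothetical usage sketch of the decorator above (the `_add_kernel` kernel and `add` wrapper are illustrative, not part of this commit). It mirrors how `triton.autotune` is normally applied; the difference is that benchmarking only happens when `ENABLE_AUTOTUNE=1` is set, otherwise `run` falls straight through to the kernel with the caller-supplied meta-parameters.

import torch
import triton
import triton.language as tl

from lightllm.common.autotuner import autotune


@autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4),
        triton.Config({"BLOCK": 256}, num_warps=8),
    ],
    key=["n_elements"],  # a new value of n_elements triggers a fresh tuning pass
)
@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) + tl.load(y_ptr + offs, mask=mask), mask=mask)


def add(x, y):
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
    # BLOCK=128 is a safe default; with ENABLE_AUTOTUNE=1 the Autotuner drops
    # this conflicting kwarg and substitutes the best benchmarked config.
    _add_kernel[grid](x, y, out, n, BLOCK=128)
    return out

Because `_bench` and `run` drop caller-supplied values that clash with a config's keys, the call site can always pass a plain default and still pick up the tuned value when autotuning is on.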

lightllm/common/basemodel/basemodel.py

Lines changed: 48 additions & 0 deletions
@@ -61,6 +61,9 @@ def __init__(self, kvargs):
         self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
         self.mem_fraction = kvargs.get("mem_fraction", 0.9)

+        # Remember whether autotuning was requested, then disable it so the
+        # init-time sanity-check forward passes below run untuned.
+        autotune = os.environ.get("ENABLE_AUTOTUNE", "0")
+        os.environ["ENABLE_AUTOTUNE"] = "0"
+
         self._init_datatype()
         self._init_config()
         self._verify_must()

@@ -74,6 +77,11 @@ def __init__(self, kvargs):
         self._init_custom()
         self._init_cudagraph()
         self._check_max_len_infer()
+
+        # Re-enable autotuning (if requested) and benchmark the prefill kernels.
+        if autotune == "1":
+            os.environ["ENABLE_AUTOTUNE"] = "1"
+            self._autotune()
+
         torch.cuda.empty_cache()
         return

@@ -509,3 +517,43 @@ def _check_max_len_infer(self):
             logger.error(exception_str)
             raise Exception(exception_str)
         return
+
+    @torch.no_grad()
+    def _check_prefill_infer(self, prefill_len):
+        # Run one dummy prefill request of the given length so every autotuned
+        # kernel on the prefill path gets benchmarked for this size.
+        dummy_input_ids = torch.ones(prefill_len, dtype=torch.int32, device="cuda")
+        b_req_idx = self.req_manager.alloc(1).int()
+        b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
+        b_seq_len[:] = prefill_len
+        b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_start_loc = torch.arange(0, 1, dtype=torch.int32, device="cuda")
+        logics = self.forward(
+            1,
+            prefill_len,
+            prefill_len,
+            dummy_input_ids,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            b_ready_cache_len=b_ready_cache_len,
+            is_prefill=True,
+            multimodal_params=[],
+        )
+        prob_out = torch.softmax(logics, dim=-1)
+        logics = None
+        torch.argmax(prob_out, dim=1, keepdim=True)
+        prob_out = None
+        self.req_manager.free_all()
+        self.mem_manager.free_all()
+
+    @torch.no_grad()
+    def _autotune(self):
+        # ---------------- Autotune Prefill ------------------------------
+        logger.info("begin prefill infer sweep for autotune.")
+        prefill_len = 1
+        while prefill_len < self.batch_max_tokens:
+            self._check_prefill_infer(prefill_len)
+            prefill_len *= 2
+        self._check_prefill_infer(self.batch_max_tokens)
+        logger.info("autotune done.")

0 commit comments