Skip to content

Commit da01e98

Browse files
committed
fix
1 parent 5942072 commit da01e98

File tree

1 file changed

+38
-53
lines changed

1 file changed

+38
-53
lines changed

lightllm/common/triton_utils/autotuner.py

Lines changed: 38 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,12 @@ def autotune(
2626
configs_gen_func: Callable[[], List],
2727
static_key_func: Callable,
2828
run_key_func: Callable,
29-
run_key_distance_func: Callable = lambda run_key, run_key_cached: abs(run_key - run_key_cached),
29+
run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)),
3030
mutates_args: List[str] = [],
3131
):
3232
def decorator(fn):
33-
arg_names = [param.name for param in inspect.signature(fn).parameters.values()]
3433
return Autotuner(
3534
fn=fn,
36-
arg_names=arg_names,
3735
kernel_name=kernel_name,
3836
configs_gen_func=configs_gen_func,
3937
static_key_func=static_key_func,
@@ -54,12 +52,11 @@ def _get_param_names(func: Callable) -> List[str]:
5452
def __init__(
5553
self,
5654
fn,
57-
arg_names: List[str],
5855
kernel_name: str,
5956
configs_gen_func: Callable[[], List],
6057
static_key_func: Callable,
6158
run_key_func: Callable,
62-
run_key_distance_func: Callable = lambda run_key, run_key_cached: abs(run_key - run_key_cached),
59+
run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)),
6360
mutates_args: List[str] = [],
6461
):
6562
# Whether to use this autotune decorator
@@ -79,20 +76,20 @@ def __init__(
7976
self.run_key_func = run_key_func
8077
self.run_key_distance_func = run_key_distance_func
8178
self.cached_configs = {}
82-
self.arg_names = arg_names
79+
self.arg_names = [param.name for param in inspect.signature(fn).parameters.values()]
8380
self._argname_to_pos = {name: idx for idx, name in enumerate(self.arg_names)}
81+
self._pos_to_argname = {idx: name for idx, name in enumerate(self.arg_names)}
8482

85-
self._static_param_names = self._get_param_names(self.static_key_func)
86-
self._run_param_names = self._get_param_names(self.run_key_func)
83+
self._static_key_func_param_names = self._get_param_names(self.static_key_func)
84+
self._run_key_func_param_names = self._get_param_names(self.run_key_func)
8785

8886
self.mutates_args = mutates_args
8987

9088
if not os.path.exists(self.cache_dir):
9189
if is_triton_autotune_enabled():
9290
os.makedirs(self.cache_dir, exist_ok=True)
9391

94-
@lru_cache(maxsize=None)
95-
def _ensure_cache_loaded(self, static_key):
92+
def _try_load_cache(self, static_key):
9693
if static_key in self.cached_configs:
9794
return
9895
cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
@@ -104,19 +101,12 @@ def _bench(self, *args, n_repeat=5, n_retries=1, **kwargs):
104101
from triton.compiler.errors import CompileTimeAssertionFailure
105102
from triton.runtime.errors import OutOfResources, PTXASError
106103

107-
full_nargs = {**self.nargs, **kwargs}
108-
109104
def kernel_call():
110-
self.pre_hook(full_nargs)
105+
new_args, new_kwargs = self._mutate_args_clone(args, kwargs)
111106
try:
112-
self.fn(*args, **kwargs)
107+
self.fn(*new_args, **new_kwargs)
113108
except Exception as e:
114-
try:
115-
self.post_hook(full_nargs, exception=e)
116-
finally:
117-
raise
118-
119-
self.post_hook(full_nargs, exception=None)
109+
raise e
120110

121111
try:
122112
# warmup
@@ -211,36 +201,42 @@ def __call__(self, *args, **kwargs):
211201
run_key = self._run_key(*args, **kwargs)
212202

213203
# Lazy load
214-
self._ensure_cache_loaded(static_key)
215-
self.nargs = dict(zip(self.arg_names, args))
204+
self._try_load_cache(static_key)
205+
206+
if is_triton_autotune_enabled():
207+
if run_key not in self.cached_configs.get(static_key, {}):
208+
self._autotune(args, kwargs, static_key, run_key)
216209

217210
if static_key not in self.cached_configs:
218-
if not is_triton_autotune_enabled():
219-
logger.warning(
220-
f"No kernel config for {self.name} - {static_key}, \
221-
using default config. Use `LIGHTLLM_TRITON_AUTOTUNE=1` to enable autotune.",
222-
)
211+
logger.warning(
212+
f"No kernel config for {self.kernel_name} - {static_key}, \
213+
using default config. Use `LIGHTLLM_TRITON_AUTOTUNE=1` to enable autotune.",
214+
)
223215
self.cached_configs[static_key] = {}
224216

225217
all_configs = self.cached_configs.get(static_key)
226-
best_config = all_configs.get(run_key)
227-
228-
if best_config is not None:
229-
kwargs["run_config"] = best_config
230-
return self.fn(*args, **kwargs)
231218

232-
if is_triton_autotune_enabled():
233-
self._autotune(args, kwargs, static_key, run_key)
234-
kwargs["run_config"] = self.cached_configs.get(static_key, {}).get(run_key)
235-
return self.fn(*args, **kwargs)
236-
237-
if all_configs != {}:
238-
closest_config = min(all_configs, key=lambda x: self.run_key_distance_func(int(x[0]), int(run_key)))[1]
239-
self.cached_configs[static_key][run_key] = closest_config
219+
if len(all_configs) != 0:
220+
closest_config = min(all_configs, key=lambda c_key: self.run_key_distance_func(run_key, c_key))
240221
kwargs["run_config"] = closest_config
241222

242223
return self.fn(*args, **kwargs)
243224

225+
def _mutate_args_clone(self, args, kwargs):
226+
new_kwargs = kwargs.copy()
227+
new_args = list(args).copy()
228+
229+
for name in self.mutates_args:
230+
if name in kwargs:
231+
new_kwargs[name] = kwargs[name].clone()
232+
else:
233+
pos = self._argname_to_pos.get(name, None)
234+
if pos is not None and pos < len(args):
235+
new_args[pos] = args[pos].clone()
236+
else:
237+
raise KeyError(f"Missing argument '{name}' required to be mutated")
238+
return tuple(new_args), new_kwargs
239+
244240
def _select_args(self, param_names, args, kwargs):
245241
if not param_names:
246242
return ()
@@ -257,16 +253,12 @@ def _select_args(self, param_names, args, kwargs):
257253
return tuple(values)
258254

259255
def _static_key(self, *args, **kwargs):
260-
if self.static_key_func is None:
261-
return "default"
262-
params = self._select_args(self._static_param_names, args, kwargs)
256+
params = self._select_args(self._static_key_func_param_names, args, kwargs)
263257
key = self.static_key_func(*params)
264258
return frozendict(key)
265259

266260
def _run_key(self, *args, **kwargs):
267-
if self.run_key_func is None:
268-
return "default"
269-
params = self._select_args(self._run_param_names, args, kwargs)
261+
params = self._select_args(self._run_key_func_param_names, args, kwargs)
270262
return self.run_key_func(*params)
271263

272264

@@ -296,10 +288,3 @@ def split_configs(configs):
296288
rank_in_node = get_current_rank_in_node()
297289
node_world_size = get_node_world_size()
298290
return configs[rank_in_node::node_world_size]
299-
300-
301-
def nearest_power_of_2(x):
302-
# Return the power of two closest to x
303-
if x <= 1:
304-
return 1
305-
return triton.next_power_of_2(x - triton.next_power_of_2(x) // 4)

0 commit comments

Comments
 (0)