@@ -26,14 +26,12 @@ def autotune(
2626 configs_gen_func : Callable [[], List ],
2727 static_key_func : Callable ,
2828 run_key_func : Callable ,
29- run_key_distance_func : Callable = lambda run_key , run_key_cached : abs (run_key - run_key_cached ),
29+ run_key_distance_func : Callable = lambda run_key , config_key : abs (int ( run_key ) - int ( config_key ) ),
3030 mutates_args : List [str ] = [],
3131):
3232 def decorator (fn ):
33- arg_names = [param .name for param in inspect .signature (fn ).parameters .values ()]
3433 return Autotuner (
3534 fn = fn ,
36- arg_names = arg_names ,
3735 kernel_name = kernel_name ,
3836 configs_gen_func = configs_gen_func ,
3937 static_key_func = static_key_func ,
@@ -54,12 +52,11 @@ def _get_param_names(func: Callable) -> List[str]:
5452 def __init__ (
5553 self ,
5654 fn ,
57- arg_names : List [str ],
5855 kernel_name : str ,
5956 configs_gen_func : Callable [[], List ],
6057 static_key_func : Callable ,
6158 run_key_func : Callable ,
62- run_key_distance_func : Callable = lambda run_key , run_key_cached : abs (run_key - run_key_cached ),
59+ run_key_distance_func : Callable = lambda run_key , config_key : abs (int ( run_key ) - int ( config_key ) ),
6360 mutates_args : List [str ] = [],
6461 ):
6562 # Whether to use this autotune decorator
@@ -79,20 +76,20 @@ def __init__(
7976 self .run_key_func = run_key_func
8077 self .run_key_distance_func = run_key_distance_func
8178 self .cached_configs = {}
82- self .arg_names = arg_names
79+ self .arg_names = [ param . name for param in inspect . signature ( fn ). parameters . values ()]
8380 self ._argname_to_pos = {name : idx for idx , name in enumerate (self .arg_names )}
81+ self ._pos_to_argname = {idx : name for idx , name in enumerate (self .arg_names )}
8482
85- self ._static_param_names = self ._get_param_names (self .static_key_func )
86- self ._run_param_names = self ._get_param_names (self .run_key_func )
83+ self ._static_key_func_param_names = self ._get_param_names (self .static_key_func )
84+ self ._run_key_func_param_names = self ._get_param_names (self .run_key_func )
8785
8886 self .mutates_args = mutates_args
8987
9088 if not os .path .exists (self .cache_dir ):
9189 if is_triton_autotune_enabled ():
9290 os .makedirs (self .cache_dir , exist_ok = True )
9391
94- @lru_cache (maxsize = None )
95- def _ensure_cache_loaded (self , static_key ):
92+ def _try_load_cache (self , static_key ):
9693 if static_key in self .cached_configs :
9794 return
9895 cache_file = os .path .join (self .cache_dir , KernelConfigs .get_config_file_name (static_key ))
@@ -104,19 +101,12 @@ def _bench(self, *args, n_repeat=5, n_retries=1, **kwargs):
104101 from triton .compiler .errors import CompileTimeAssertionFailure
105102 from triton .runtime .errors import OutOfResources , PTXASError
106103
107- full_nargs = {** self .nargs , ** kwargs }
108-
109104 def kernel_call ():
110- self .pre_hook ( full_nargs )
105+ new_args , new_kwargs = self ._mutate_args_clone ( args , kwargs )
111106 try :
112- self .fn (* args , ** kwargs )
107+ self .fn (* new_args , ** new_kwargs )
113108 except Exception as e :
114- try :
115- self .post_hook (full_nargs , exception = e )
116- finally :
117- raise
118-
119- self .post_hook (full_nargs , exception = None )
109+ raise e
120110
121111 try :
122112 # warmup
@@ -211,36 +201,42 @@ def __call__(self, *args, **kwargs):
211201 run_key = self ._run_key (* args , ** kwargs )
212202
213203 # Lazy load
214- self ._ensure_cache_loaded (static_key )
215- self .nargs = dict (zip (self .arg_names , args ))
204+ self ._try_load_cache (static_key )
205+
206+ if is_triton_autotune_enabled ():
207+ if run_key not in self .cached_configs .get (static_key , {}):
208+ self ._autotune (args , kwargs , static_key , run_key )
216209
217210 if static_key not in self .cached_configs :
218- if not is_triton_autotune_enabled ():
219- logger .warning (
220- f"No kernel config for { self .name } - { static_key } , \
221- using default config. Use `LIGHTLLM_TRITON_AUTOTUNE=1` to enable autotune." ,
222- )
211+ logger .warning (
212+ f"No kernel config for { self .kernel_name } - { static_key } , \
213+ using default config. Use `LIGHTLLM_TRITON_AUTOTUNE=1` to enable autotune." ,
214+ )
223215 self .cached_configs [static_key ] = {}
224216
225217 all_configs = self .cached_configs .get (static_key )
226- best_config = all_configs .get (run_key )
227-
228- if best_config is not None :
229- kwargs ["run_config" ] = best_config
230- return self .fn (* args , ** kwargs )
231218
232- if is_triton_autotune_enabled ():
233- self ._autotune (args , kwargs , static_key , run_key )
234- kwargs ["run_config" ] = self .cached_configs .get (static_key , {}).get (run_key )
235- return self .fn (* args , ** kwargs )
236-
237- if all_configs != {}:
238- closest_config = min (all_configs , key = lambda x : self .run_key_distance_func (int (x [0 ]), int (run_key )))[1 ]
239- self .cached_configs [static_key ][run_key ] = closest_config
219+ if len (all_configs ) != 0 :
220+ closest_config = min (all_configs , key = lambda c_key : self .run_key_distance_func (run_key , c_key ))
240221 kwargs ["run_config" ] = closest_config
241222
242223 return self .fn (* args , ** kwargs )
243224
225+ def _mutate_args_clone (self , args , kwargs ):
226+ new_kwargs = kwargs .copy ()
227+ new_args = list (args ).copy ()
228+
229+ for name in self .mutates_args :
230+ if name in kwargs :
231+ new_kwargs [name ] = kwargs [name ].clone ()
232+ else :
233+ pos = self ._argname_to_pos .get (name , None )
234+ if pos is not None and pos < len (args ):
235+ new_args [pos ] = args [pos ].clone ()
236+ else :
237+ raise KeyError (f"Missing argument '{ name } ' required to be mutated" )
238+ return tuple (new_args ), new_kwargs
239+
244240 def _select_args (self , param_names , args , kwargs ):
245241 if not param_names :
246242 return ()
@@ -257,16 +253,12 @@ def _select_args(self, param_names, args, kwargs):
257253 return tuple (values )
258254
259255 def _static_key (self , * args , ** kwargs ):
260- if self .static_key_func is None :
261- return "default"
262- params = self ._select_args (self ._static_param_names , args , kwargs )
256+ params = self ._select_args (self ._static_key_func_param_names , args , kwargs )
263257 key = self .static_key_func (* params )
264258 return frozendict (key )
265259
266260 def _run_key (self , * args , ** kwargs ):
267- if self .run_key_func is None :
268- return "default"
269- params = self ._select_args (self ._run_param_names , args , kwargs )
261+ params = self ._select_args (self ._run_key_func_param_names , args , kwargs )
270262 return self .run_key_func (* params )
271263
272264
@@ -296,10 +288,3 @@ def split_configs(configs):
296288 rank_in_node = get_current_rank_in_node ()
297289 node_world_size = get_node_world_size ()
298290 return configs [rank_in_node ::node_world_size ]
299-
300-
301- def nearest_power_of_2 (x ):
302- # Return the power of two closest to x
303- if x <= 1 :
304- return 1
305- return triton .next_power_of_2 (x - triton .next_power_of_2 (x ) // 4 )
0 commit comments