1919def autotune (
2020 name ,
2121 configs ,
22- default_config ,
2322 static_key_func = None ,
2423 run_key_func = None ,
2524 reset_to_zero = None ,
@@ -34,7 +33,6 @@ def decorator(fn):
3433 arg_names ,
3534 name ,
3635 configs ,
37- default_config ,
3836 static_key_func ,
3937 run_key_func ,
4038 reset_to_zero ,
@@ -46,6 +44,16 @@ def decorator(fn):
4644 return decorator
4745
4846
def is_triton_autotune_enabled() -> bool:
    """Report whether Triton autotuning is switched on.

    This is a read-only check: it inspects the ``LIGHTLLM_TRITON_AUTOTUNE``
    environment variable on every call and never modifies it.

    Returns:
        ``True`` only when the variable is set to exactly ``"1"``. An unset
        variable falls back to ``"0"`` and therefore reports ``False``.
    """
    return os.environ.get("LIGHTLLM_TRITON_AUTOTUNE", "0") == "1"
50+
51+
def disable_triton_autotune() -> None:
    """Switch Triton autotuning off for the current process.

    Writes ``"0"`` to the ``LIGHTLLM_TRITON_AUTOTUNE`` environment variable,
    overwriting any previous value. Code that reads the variable after this
    call sees autotuning as disabled.
    """
    os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "0"
55+
56+
4957class Autotuner :
5058 @staticmethod
5159 def _get_param_names (func ):
@@ -72,7 +80,6 @@ def __init__(
7280 arg_names ,
7381 name ,
7482 configs ,
75- default_config ,
7683 static_key_func ,
7784 run_key_func ,
7885 reset_to_zero ,
@@ -82,11 +89,13 @@ def __init__(
8289 warmup = None ,
8390 rep = None ,
8491 ):
85- # 是否打印autotune信息
92+ # Whether to print autotune logs
8693 self .print_autotune = os .environ .get ("LIGHTLLM_TRITON_PRINT_AUTOTUNE" , "0" ) == "1"
94+ # Whether to use this autotune decorator
95+ self .disable_autotune = os .environ .get ("DISABLE_AUTOTUNE_DECORATOR" , "0" ) == "1"
96+
8797 self .all_configs = configs
8898 self .configs = None
89- self .default_config = default_config
9099 self .name = name
91100 self .cache_dir = os .path .join (
92101 Path (__file__ ).parent , "all_kernel_configs" , get_triton_version (), get_current_device_name (), self .name
@@ -95,11 +104,6 @@ def __init__(
95104 self .static_key_func = static_key_func
96105 self .run_key_func = run_key_func
97106
98- # 是否使用之前配置
99- self .can_be_none = os .environ .get ("DISABLE_MANUAL_TUNE_CONFIG" , "0" ) == "0"
100- # 是否使用autotune注解
101- self .disable_autotune = os .environ .get ("DISABLE_AUTOTUNE_ANNOTATION" , "0" ) == "1"
102-
103107 self .cached_configs = {}
104108 self .arg_names = arg_names
105109 self ._argname_to_pos = {name : idx for idx , name in enumerate (self .arg_names )}
@@ -165,7 +169,6 @@ def _ensure_cache_loaded(self, static_key: str):
165169 ]
166170 self .sorted_cached_configs [static_key ].sort (key = lambda x : x [0 ])
167171 except Exception :
168- # 若缓存损坏,忽略并在之后覆盖
169172 self .cached_configs [static_key ] = {}
170173 self ._loaded_static_keys .add (static_key )
171174
@@ -229,9 +232,8 @@ def __call__(self, *args, **kwargs):
229232 static_key = self ._static_key (* args , ** kwargs )
230233 run_key = self ._run_key (* args , ** kwargs )
231234
232- # 懒加载
235+ # Lazy load
233236 self ._ensure_cache_loaded (static_key )
234- best_config = None
235237 self .nargs = dict (zip (self .arg_names , args ))
236238
237239 def _benchmark (_run_key ):
@@ -300,30 +302,30 @@ def _benchmark(_run_key):
300302 full_nargs = {** self .nargs , ** kwargs }
301303 self .pre_hook (full_nargs , reset_only = True )
302304
303- best_config = self .cached_configs .get (static_key , {}).get (run_key )
305+ if static_key not in self .cached_configs :
306+ if not is_triton_autotune_enabled ():
307+ logger .warning (
308+ f"No kernel config for { self .name } in { self .cache_dir } /{ static_key } , using default config" ,
309+ )
310+ self .cached_configs [static_key ] = {}
311+
312+ all_configs = self .cached_configs .get (static_key )
313+ best_config = all_configs .get (run_key )
314+
315+ if best_config is not None :
316+ kwargs ["run_config" ] = best_config
317+ return self .fn (* args , ** kwargs )
318+
319+ if is_triton_autotune_enabled ():
320+ _benchmark (run_key )
321+ kwargs ["run_config" ] = self .cached_configs .get (static_key , {}).get (run_key )
322+ return self .fn (* args , ** kwargs )
323+
324+ if all_configs != {}:
325+ closest_config = min (all_configs , key = lambda x : abs (int (x [0 ]) - int (run_key )))[1 ]
326+ self .cached_configs [static_key ][run_key ] = closest_config
327+ kwargs ["run_config" ] = closest_config
304328
305- if best_config is None :
306- if os .environ .get ("LIGHTLLM_TRITON_AUTOTUNE" , "0" ) == "1" :
307- _benchmark (run_key )
308- elif not self .can_be_none :
309- cached_for_static = self .cached_configs .setdefault (static_key , {})
310- if static_key in self .sorted_cached_configs and self .sorted_cached_configs [static_key ]:
311- sorted_configs = self .sorted_cached_configs [static_key ]
312- try :
313- target = int (run_key )
314- cached_for_static [run_key ] = min (sorted_configs , key = lambda x : abs (x [0 ] - target ))[1 ]
315- except Exception :
316- cached_for_static [run_key ] = self .default_config
317- else :
318- if static_key not in self .sorted_cached_configs :
319- logger .warning (
320- f"No kernel config for { self .name } in { self .cache_dir } /{ static_key } , using default config"
321- )
322- cached_for_static [run_key ] = self .default_config
323-
324- best_config = self .cached_configs [static_key ][run_key ]
325-
326- kwargs ["run_config" ] = best_config
327329 return self .fn (* args , ** kwargs )
328330
329331 def _select_args (self , param_names , args , kwargs ):
@@ -412,7 +414,7 @@ def dict_to_filename(data):
412414
413415
def nearest_power_of_2(x):
    """Return the power of two closest to ``x``.

    Args:
        x: The value to round. Anything less than or equal to 1 maps to 1.

    Returns:
        1 when ``x <= 1``; otherwise the result of
        ``triton.next_power_of_2`` applied to ``x`` shifted down by a quarter
        of its own next power of two.
    """
    if x <= 1:
        return 1
    # Assuming triton.next_power_of_2 rounds up to the smallest power of two
    # >= its argument (confirm against the installed Triton version):
    # subtracting a quarter of that upper bound pushes values at or below the
    # midpoint of the surrounding powers of two down into the lower bucket,
    # so rounding up afterwards lands on the nearer power.
    return triton.next_power_of_2(x - triton.next_power_of_2(x) // 4)
0 commit comments