from lightllm.utils.device_utils import get_current_device_name
from lightllm.utils.log_utils import init_logger
from typing import Callable, Optional, Union, List
- from lightllm.utils.envs_utils import is_triton_autotune_enabled
+ from lightllm.utils.envs_utils import get_triton_autotune_level
from lightllm.common.kernel_config import KernelConfigs
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node
- from lightllm.distributed.communication_op import dist_group_manager

logger = init_logger(__name__)


+ def _get_autotune_group():
+     # Import inside the function so communication_op is only needed once a group is actually requested.
+     from lightllm.distributed.communication_op import dist_group_manager
+     return dist_group_manager.get_default_group().autotune_group
+
+
+ class AutotuneLevel:
+     """Values returned by get_triton_autotune_level() that control autotuning behavior."""
+
+     # Do not autotune; only use configs from cached files.
+     NO_AUTOTUNE = 0
+     # Autotune only if no config is cached.
+     AUTOTUNE = 1
+     # Always autotune, overwriting the configs in cached files.
+     AUTOTUNE_OVERWRITE = 2
+     # Autotune at runtime to keep searching for better configs.
+     AUTOTUNE_RUNTIME = 3
+     # Autotune at runtime to keep searching for better configs, and overwrite the configs in cached files.
+     AUTOTUNE_RUNTIME_OVERWRITE = 4
+     # Disable autotune and do not use configs from cached files.
+     CLOSE_AUTOTUNE = 5
+
+
def autotune(
    kernel_name: str,
    configs_gen_func: Callable[[], List],
@@ -28,6 +47,29 @@ def autotune(
    run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)),
    mutates_args: List[str] = [],
):
50+ """Decorator that constructs and returns an Autotuner wrapper for a Triton kernel.
51+
52+ This decorator configures an Autotuner with the provided configuration
53+ generator and key functions, enabling on-demand benchmarking and caching
54+ of kernel run configurations across runs and processes.
55+
56+ Args:
57+ kernel_name (str): Human-readable kernel name used for logging and cache paths.
58+ configs_gen_func (Callable[[], List]): Function that returns candidate run configurations.
59+ static_key_func (Callable): Function that derives a static key (dict-like) from call arguments.
60+ This key identifies the cache file that stores tuned configs.
61+ run_key_func (Callable): Function that derives a run-time key from call arguments.
62+ This key indexes tuned configs within a static key's cache.
63+ run_key_distance_func (Callable, optional): Distance metric taking ``(run_key, config_key)`` and
64+ returning a comparable value; used to pick the closest config when an exact match is absent.
65+ Defaults to ``abs(int(run_key) - int(config_key))``.
66+ mutates_args (List[str], optional): Names of arguments that can be mutated by the kernel.
67+ During benchmarking, defensive clones are made to avoid side effects. Defaults to ``[]``.
68+
69+ Returns:
70+ Callable: A callable object that wraps the original function and performs autotuning
71+ as needed before invocation.
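+
+     Example (a hypothetical sketch; ``my_kernel`` and its key functions are
+     illustrative, not part of this module)::
+
+         @autotune(
+             kernel_name="my_kernel",
+             configs_gen_func=lambda: [{"BLOCK": 64}, {"BLOCK": 128}],
+             static_key_func=lambda x: {"dtype": str(x.dtype)},
+             run_key_func=lambda x: x.shape[0],
+         )
+         def my_kernel(x, run_config=None):
+             ...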
72+ """
    def decorator(fn):
        return Autotuner(
            fn=fn,
@@ -53,8 +95,7 @@ def __init__(
        run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)),
        mutates_args: List[str] = [],
    ):
-         # Whether to use this autotune decorator
-         self.disable_autotune = not is_triton_autotune_enabled()
+         self.autotune_level = get_triton_autotune_level()

        self.configs_gen_func = configs_gen_func
        self.kernel_name = kernel_name
@@ -65,7 +106,6 @@ def __init__(
            get_current_device_name(),
            self.kernel_name,
        )
-         os.makedirs(self.cache_dir, exist_ok=True)
        self.fn = fn
        self.static_key_func = static_key_func
        self.run_key_func = run_key_func
@@ -81,38 +121,42 @@ def __init__(
        ]
        self._run_key_func_param_names = [name for name, _ in inspect.signature(self.run_key_func).parameters.items()]
        self.mutates_args = mutates_args
+
+         assert self.autotune_level in [
+             AutotuneLevel.NO_AUTOTUNE,
+             AutotuneLevel.AUTOTUNE,
+             AutotuneLevel.AUTOTUNE_OVERWRITE,
+             AutotuneLevel.AUTOTUNE_RUNTIME,
+             AutotuneLevel.AUTOTUNE_RUNTIME_OVERWRITE,
+             AutotuneLevel.CLOSE_AUTOTUNE,
+         ]
        return

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        if kwargs.get("run_config", None) is not None:
            return self.fn(*args, **kwargs)

-         if self.disable_autotune:
+         # If autotune_level is AutotuneLevel.CLOSE_AUTOTUNE, bypass autotuning entirely.
+         if self.autotune_level == AutotuneLevel.CLOSE_AUTOTUNE:
            return self.fn(*args, **kwargs)

        rank_id = 0 if not dist.is_initialized() else get_global_rank()
        world_size = 1 if not dist.is_initialized() else get_global_world_size()

-         static_key = self._static_key(*args, **kwargs)
+         static_key = frozendict(self._static_key(*args, **kwargs))
        run_key = str(self._run_key(*args, **kwargs))

-         # Lazy load
-         self._try_load_cache(static_key)
+         # Lazily load configs cached under lightllm/common/triton_utils/autotune_kernel_configs.
+         if self.autotune_level not in [AutotuneLevel.AUTOTUNE_OVERWRITE, AutotuneLevel.AUTOTUNE_RUNTIME_OVERWRITE]:
+             self._try_load_cache(static_key)

-         if static_key not in self.cached_configs:
+         if static_key not in self.cached_configs and self.autotune_level == AutotuneLevel.NO_AUTOTUNE:
            if (dist.is_initialized() and get_current_rank_in_node() == 0) or not dist.is_initialized():
                logger.warning(
                    f"No kernel config for {self.kernel_name} in {KernelConfigs.get_config_file_name(static_key)}",
                )
            self.cached_configs[static_key] = {}

-         if is_triton_autotune_enabled():
+         if self.autotune_level != AutotuneLevel.NO_AUTOTUNE:
            need_tunning = run_key not in self.cached_configs.get(static_key, {})
            if world_size > 1:
                _need_tunnings = [None for _ in range(world_size)]
                dist.all_gather_object(
-                     _need_tunnings, obj=need_tunning, group=dist_group_manager.get_default_group().autotune_group
+                     _need_tunnings, obj=need_tunning, group=_get_autotune_group()
                )
                need_tunning = any(_need_tunnings)
            if need_tunning:
@@ -125,12 +169,12 @@ def __call__(self, *args, **kwargs):
                world_size=world_size,
            )

-         if static_key in self.fast_match_configs and run_key in self.fast_match_configs[static_key]:
-             closest_config = self.fast_match_configs[static_key][run_key]
-             kwargs["run_config"] = closest_config
+         fast_for_key = self.fast_match_configs.get(static_key)
+         if fast_for_key is not None and run_key in fast_for_key:
+             kwargs["run_config"] = fast_for_key[run_key]
            return self.fn(*args, **kwargs)

-         all_configs = self.cached_configs.get(static_key)
+         all_configs = self.cached_configs.get(static_key, {})
        if len(all_configs) != 0:
            closest_config = min(
                list(all_configs.items()), key=lambda item: self.run_key_distance_func(run_key, item[0])
@@ -146,6 +190,7 @@ def _try_load_cache(self, static_key):

        cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
        if os.path.exists(cache_file):
+             logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
            with open(cache_file, "rb") as f:
                self.cached_configs[static_key] = orjson.loads(f.read())
            return
@@ -195,7 +240,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
        all_keys = [None for _ in range(world_size)]
        all_key_str = f"{run_key}_{static_key}"
        dist.all_gather_object(
-             all_keys, obj=all_key_str, group=dist_group_manager.get_default_group().autotune_group
+             all_keys, obj=all_key_str, group=_get_autotune_group()
        )
        is_key_all_same = all(all_keys[0] == k for k in all_keys)
        if not is_key_all_same:
@@ -237,7 +282,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
        dist.all_gather_object(
            all_gather_configs,
            obj=(best_time, run_key, dict(static_key), best_config),
-             group=dist_group_manager.get_default_group().autotune_group,
+             group=_get_autotune_group(),
        )
        all_gather_configs = sorted(all_gather_configs, key=lambda x: x[0])
        key_set = set()
@@ -318,8 +363,7 @@ def _select_args(self, param_names, args, kwargs):

    def _static_key(self, *args, **kwargs):
        params = self._select_args(self._static_key_func_param_names, args, kwargs)
-         key = self.static_key_func(*params)
-         return frozendict(key)
+         return self.static_key_func(*params)

    def _run_key(self, *args, **kwargs):
        params = self._select_args(self._run_key_func_param_names, args, kwargs)
@@ -347,3 +391,7 @@ def get_triton_version():
def split_configs(configs, global_rank, global_world_size):
    random.Random(0).shuffle(configs)
    return configs[global_rank::global_world_size]
+
+
+ def closest_pow_of_2(x):
+     # Round x to the nearest power of two (halfway values round down).
+     return triton.next_power_of_2(x - triton.next_power_of_2(x) // 4)