1313from lightllm .utils .log_utils import init_logger
1414import traceback
1515from typing import Callable , Optional , Union , List
16+ from lightllm .utils .envs_utils import is_triton_autotune_enabled , disable_triton_autotune
1617
1718logger = init_logger (__name__ )
1819
@@ -45,16 +46,6 @@ def decorator(fn):
4546 return decorator
4647
4748
48- def is_triton_autotune_enabled ():
49- # Whether Triton autotune is enabled (read-only check)
50- return os .environ .get ("LIGHTLLM_TRITON_AUTOTUNE" , "0" ) == "1"
51-
52-
53- def disable_triton_autotune ():
54- # Disable Triton autotune (setter)
55- os .environ ["LIGHTLLM_TRITON_AUTOTUNE" ] = "0"
56-
57-
5849class Autotuner :
5950 @staticmethod
6051 def _get_param_names (func ):
@@ -148,8 +139,6 @@ def _post_hook(kwargs, exception):
148139 os .makedirs (self .cache_dir , exist_ok = True )
149140
150141 self ._loaded_static_keys = set ()
151- self .sorted_cached_configs = {}
152- self .early_stop_cnt = 0
153142
154143 @lru_cache (maxsize = None )
155144 def _ensure_cache_loaded (self , static_key : str ):
@@ -160,15 +149,11 @@ def _ensure_cache_loaded(self, static_key: str):
160149 try :
161150 with open (cache_file , "rb" ) as f :
162151 self .cached_configs [static_key ] = orjson .loads (f .read ())
163- self .sorted_cached_configs [static_key ] = [
164- (int (k ), v ) for k , v in self .cached_configs [static_key ].items ()
165- ]
166- self .sorted_cached_configs [static_key ].sort (key = lambda x : x [0 ])
167152 except Exception :
168153 self .cached_configs [static_key ] = {}
169154 self ._loaded_static_keys .add (static_key )
170155
171- def _bench (self , * args , n_repeat = 5 , n_retries = 1 , current_best_ms = None , ** kwargs ):
156+ def _bench (self , * args , n_repeat = 5 , n_retries = 1 , ** kwargs ):
172157 from triton .compiler .errors import CompileTimeAssertionFailure
173158 from triton .runtime .errors import OutOfResources , PTXASError
174159
@@ -234,26 +219,25 @@ def _benchmark(_run_key):
234219 rank_id = get_global_rank ()
235220 _best_config = self .default_config
236221 best_time = float ("inf" )
237- self . early_stop_cnt = 0
222+
238223 bar = tqdm (
239224 self .configs ,
240- desc = f"Autotuning { self .name } for { _run_key } , es: { self . early_stop_cnt / len ( self . configs ):.2% } " ,
225+ desc = f"Autotuning { self .name } for { _run_key } " ,
241226 position = get_global_rank (),
242227 dynamic_ncols = True ,
243228 )
244229 enum_configs = enumerate (bar )
245230 for i , config in enum_configs :
246231 kwargs_with_config = kwargs .copy ()
247232 kwargs_with_config ["run_config" ] = config
248- run_time = self ._bench (* args , current_best_ms = best_time , ** kwargs_with_config )
233+ run_time = self ._bench (* args , ** kwargs_with_config )
249234 if run_time < best_time :
250235 best_time = run_time
251236 _best_config = config
252237 bar .set_description (
253- f"Autotuning { self .name } [rank:{ rank_id } ] \
254- for { _run_key } , es:{ self .early_stop_cnt / len (self .configs ):.2%} , \
255- best_time: { best_time :.5f} "
238+ f"Autotuning { self .name } [rank:{ rank_id } ] for { _run_key } , best_time: { best_time :.5f} "
256239 )
240+
257241 world_size = dist .get_world_size () if dist .is_initialized () else 1
258242 if world_size > 1 :
259243 local_best = torch .tensor ([best_time ], device = "cuda" )
@@ -268,24 +252,21 @@ def _benchmark(_run_key):
268252 if static_key not in self .cached_configs :
269253 self .cached_configs [static_key ] = {}
270254 self .cached_configs [static_key ][run_key ] = _best_config
271- self .sorted_cached_configs [static_key ] = [(int (k ), v ) for k , v in self .cached_configs [static_key ].items ()]
272- self .sorted_cached_configs [static_key ].sort (key = lambda x : x [0 ])
273255
256+ # save configs to file
274257 if not dist .is_initialized () or get_global_rank () == 0 :
275- if os .environ .get ("LIGHTLLM_TRITON_AUTOTUNE" , "0" ) == "1" :
276- cache_file = os .path .join (self .cache_dir , f"{ static_key } .json" )
277- with open (cache_file , "wb" ) as f :
278- fcntl .flock (f , fcntl .LOCK_EX )
279- try :
280- f .write (
281- orjson .dumps (
282- self .cached_configs [static_key ], option = orjson .OPT_INDENT_2 | orjson .OPT_SORT_KEYS
283- )
258+ cache_file = os .path .join (self .cache_dir , f"{ static_key } .json" )
259+ with open (cache_file , "wb" ) as f :
260+ fcntl .flock (f , fcntl .LOCK_EX )
261+ try :
262+ f .write (
263+ orjson .dumps (
264+ self .cached_configs [static_key ], option = orjson .OPT_INDENT_2 | orjson .OPT_SORT_KEYS
284265 )
285- finally :
286- fcntl . flock ( f , fcntl . LOCK_UN )
287- if self . print_autotune :
288- logger .info (f"Saved configs for { self .name } - { static_key } - { run_key } " )
266+ )
267+ finally :
268+ fcntl . flock ( f , fcntl . LOCK_UN )
269+ logger .info (f"Saved configs for { self .name } - { static_key } - { run_key } " )
289270
290271 kwargs ["run_config" ] = self .cached_configs [static_key ][run_key ]
291272 full_nargs = {** self .nargs , ** kwargs }
@@ -294,7 +275,8 @@ def _benchmark(_run_key):
294275 if static_key not in self .cached_configs :
295276 if not is_triton_autotune_enabled ():
296277 logger .warning (
297- f"No kernel config for { self .name } in { self .cache_dir } /{ static_key } , using default config" ,
278+ f"No kernel config for { self .name } - { static_key } , \
279+ using default config. Use `LIGHTLLM_TRITON_AUTOTUNE=1` to enable autotune." ,
298280 )
299281 self .cached_configs [static_key ] = {}
300282
0 commit comments