1919logger = init_logger (__name__ )
2020
2121
22- def _get_autotune_group ():
23- from lightllm .distributed .communication_op import dist_group_manager
24-
25- return dist_group_manager .get_default_group ().autotune_group
26-
27-
class AutotuneLevel:
    """Integer levels selecting how the kernel autotuner treats cached configs.

    Every level must have a distinct value: the levels are compared by value,
    so two names sharing one integer would be indistinguishable to callers.
    The numbering (0..3) keeps the same meaning per integer as the previous
    NO_AUTOTUNE / AUTOTUNE / AUTOTUNE_OVERWRITE / CLOSE_AUTOTUNE constants.
    """

    # Do not autotune, only use the config of cached files.
    USE_AUTOTUNE_HIS_CONFIG = 0
    # Autotune if no config is cached.
    ADAPTIVE_AUTOTUNE = 1
    # Autotune anyway to overwrite the config of cached files.
    FORCE_AUTOTUNE = 2
    # Close autotune and not use the config of cached files.
    CLOSE_AUTOTUNE = 3
3730
3831
3932def autotune (
@@ -104,6 +97,7 @@ def __init__(
10497 get_current_device_name (),
10598 self .kernel_name ,
10699 )
100+ os .makedirs (self .cache_dir , exist_ok = True )
107101 self .fn = fn
108102 self .static_key_func = static_key_func
109103 self .run_key_func = run_key_func
@@ -163,7 +157,7 @@ def __call__(self, *args, **kwargs):
163157 )
164158 if world_size > 1 :
165159 _need_tunings = [None for _ in range (world_size )]
166- dist .all_gather_object (_need_tunings , obj = need_tuning , group = _get_autotune_group ())
160+ dist .all_gather_object (_need_tunings , obj = need_tuning , group = self . _get_autotune_group ())
167161 need_tuning = any (_need_tunings )
168162 if need_tuning :
169163 self ._autotune (
@@ -245,7 +239,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
245239 if world_size > 1 :
246240 all_keys = [None for _ in range (world_size )]
247241 all_key_str = f"{ run_key } _{ static_key } "
248- dist .all_gather_object (all_keys , obj = all_key_str , group = _get_autotune_group ())
242+ dist .all_gather_object (all_keys , obj = all_key_str , group = self . _get_autotune_group ())
249243 is_key_all_same = all (all_keys [0 ] == k for k in all_keys )
250244 if not is_key_all_same :
251245 logger .warning (
@@ -286,7 +280,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
286280 dist .all_gather_object (
287281 all_gather_configs ,
288282 obj = (best_time , run_key , dict (static_key ), best_config ),
289- group = _get_autotune_group (),
283+ group = self . _get_autotune_group (),
290284 )
291285 all_gather_configs = sorted (all_gather_configs , key = lambda x : x [0 ])
292286 key_set = set ()
@@ -312,7 +306,6 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
312306 if rank_id == 0 :
313307 for _static_key in update_static_key_list :
314308 cache_file = os .path .join (self .cache_dir , KernelConfigs .get_config_file_name (_static_key ))
315- os .makedirs (self .cache_dir , exist_ok = True )
316309 with open (cache_file , "wb" ) as f :
317310 f .write (
318311 orjson .dumps (
@@ -374,6 +367,12 @@ def _run_key(self, *args, **kwargs):
374367 params = self ._select_args (self ._run_key_func_param_names , args , kwargs )
375368 return self .run_key_func (* params )
376369
def _get_autotune_group(self):
    """Return the ``autotune_group`` of the default distributed group.

    The returned object is what this class passes as ``group=`` to
    ``dist.all_gather_object`` when ranks exchange autotune state.
    """
    # Imported at call time rather than module import time
    # (presumably to avoid a circular import -- confirm against the package layout).
    from lightllm.distributed.communication_op import dist_group_manager

    default_group = dist_group_manager.get_default_group()
    return default_group.autotune_group
374+
375+
377376
378377class _BenchmarkState :
379378 def __init__ (self ):
@@ -397,6 +396,3 @@ def split_configs(configs, global_rank, global_world_size):
397396 random .Random (0 ).shuffle (configs )
398397 return configs [global_rank ::global_world_size ]
399398
400-
401- def closest_pow_of_2 (x ):
402- return triton .next_power_of_2 (x - triton .next_power_of_2 (x ) // 4 )
0 commit comments