2020
2121
class AutotuneLevel:
    """Integer codes that select how kernel autotuning treats cached configs.

    The values are plain ints so they can be compared directly against an
    integer level, e.g. ``level in [AutotuneLevel.ADAPTIVE_AUTOTUNE, ...]``.
    """

    # Use the config of cached files in
    # /lightllm/common/triton_utils/autotune_kernel_configs.
    USE_AUTOTUNE_HIS_CONFIG = 0
    # Autotune only if no config is cached.
    ADAPTIVE_AUTOTUNE = 1
    # Always autotune, overwriting the config of cached files.
    FORCE_AUTOTUNE = 2
    # Disable autotune and use the configs of cached files in
    # lightllm/common/all_kernel_configs.
    CLOSE_AUTOTUNE = 3
3031
3132
@@ -148,7 +149,9 @@ def __call__(self, *args, **kwargs):
148149 self .cached_configs [static_key ] = {}
149150
150151 if autotune_level in [AutotuneLevel .ADAPTIVE_AUTOTUNE , AutotuneLevel .FORCE_AUTOTUNE ]:
151- need_tuning = (autotune_level == AutotuneLevel .FORCE_AUTOTUNE ) or (run_key not in self .cached_configs .get (static_key , {}))
152+ need_tuning = (autotune_level == AutotuneLevel .FORCE_AUTOTUNE ) or (
153+ run_key not in self .cached_configs .get (static_key , {})
154+ )
152155 if world_size > 1 :
153156 _need_tunings = [None for _ in range (world_size )]
154157 dist .all_gather_object (_need_tunings , obj = need_tuning , group = self ._get_autotune_group ())
@@ -162,7 +165,7 @@ def __call__(self, *args, **kwargs):
162165 rank_id = rank_id ,
163166 world_size = world_size ,
164167 )
165-
168+
166169 closest_config = self .fast_match_configs .get (static_key , {}).get (run_key , None )
167170 if closest_config is not None :
168171 kwargs ["run_config" ] = closest_config
@@ -361,13 +364,14 @@ def _run_key(self, *args, **kwargs):
361364 params = self ._select_args (self ._run_key_func_param_names , args , kwargs )
362365 return self .run_key_func (* params )
363366
def _get_autotune_group(self):
    """Return the ``autotune_group`` of the default distributed group.

    The result is what ``__call__`` passes as ``group=`` to
    ``dist.all_gather_object`` when ranks exchange their tuning decision.
    """
    # Imported at call time, exactly as in the original.
    # NOTE(review): presumably deferred to avoid an import-time dependency
    # on the distributed package -- confirm before hoisting to module level.
    from lightllm.distributed.communication_op import dist_group_manager

    return dist_group_manager.get_default_group().autotune_group
368373
369374
370-
371375class _BenchmarkState :
372376 def __init__ (self ):
373377 self .sum = 0
@@ -389,4 +393,3 @@ def get_triton_version():
def split_configs(configs, global_rank, global_world_size):
    """Return the shard of ``configs`` assigned to ``global_rank``.

    The configs are shuffled with a fixed seed (0), so every rank that passes
    the same configs derives the same order; rank ``r`` then takes every
    ``global_world_size``-th element starting at index ``r``. Taken over all
    ranks, the shards partition the input.

    Args:
        configs: Iterable of candidate configs. It is copied, never mutated.
        global_rank: Index of the calling rank; used as the slice start.
        global_world_size: Total number of ranks; used as the slice step.

    Returns:
        A new list holding this rank's share of the shuffled configs.

    Raises:
        ValueError: If ``global_world_size`` is 0 (slice step cannot be zero).
    """
    # Shuffle a private copy. Shuffling the caller's list in place meant a
    # second call on the same list object re-shuffled already shuffled data,
    # so shards computed from one list no longer formed a partition.
    shuffled = list(configs)
    random.Random(0).shuffle(shuffled)
    return shuffled[global_rank::global_world_size]
392-
0 commit comments