2121
def _get_autotune_group():
    """Return the autotune process group of the default distributed group.

    NOTE(review): the import is deferred to call time — presumably to avoid a
    circular import with ``lightllm.distributed.communication_op``; confirm
    before hoisting it to module level.
    """
    from lightllm.distributed.communication_op import dist_group_manager

    default_group = dist_group_manager.get_default_group()
    return default_group.autotune_group
2526
2627
@@ -70,6 +71,7 @@ def autotune(
7071 Callable: A callable object that wraps the original function and performs autotuning
7172 as needed before invocation.
7273 """
74+
7375 def decorator (fn ):
7476 return Autotuner (
7577 fn = fn ,
@@ -121,8 +123,15 @@ def __init__(
121123 ]
122124 self ._run_key_func_param_names = [name for name , _ in inspect .signature (self .run_key_func ).parameters .items ()]
123125 self .mutates_args = mutates_args
124-
125- assert self .autotune_level in [AutotuneLevel .NO_AUTOTUNE , AutotuneLevel .AUTOTUNE , AutotuneLevel .AUTOTUNE_OVERWRITE , AutotuneLevel .AUTOTUNE_RUNTIME , AutotuneLevel .AUTOTUNE_RUNTIME_OVERWRITE , AutotuneLevel .CLOSE_AUTOTUNE ]
126+
127+ assert self .autotune_level in [
128+ AutotuneLevel .NO_AUTOTUNE ,
129+ AutotuneLevel .AUTOTUNE ,
130+ AutotuneLevel .AUTOTUNE_OVERWRITE ,
131+ AutotuneLevel .AUTOTUNE_RUNTIME ,
132+ AutotuneLevel .AUTOTUNE_RUNTIME_OVERWRITE ,
133+ AutotuneLevel .CLOSE_AUTOTUNE ,
134+ ]
126135 return
127136
128137 @torch .no_grad ()
@@ -131,7 +140,7 @@ def __call__(self, *args, **kwargs):
131140 return self .fn (* args , ** kwargs )
132141
133142 # if the autotune_level is AutotuneLevel.CLOSE_AUTOTUNE, ignore the autotune
134- if self . autotune_level == AutotuneLevel .CLOSE_AUTOTUNE :
143+ if get_triton_autotune_level () == AutotuneLevel .CLOSE_AUTOTUNE :
135144 return self .fn (* args , ** kwargs )
136145
137146 rank_id = 0 if not dist .is_initialized () else get_global_rank ()
@@ -141,23 +150,24 @@ def __call__(self, *args, **kwargs):
141150 run_key = str (self ._run_key (* args , ** kwargs ))
142151
143152 # Lazy load the cached configs in lightllm/common/triton_utils/autotune_kernel_configs
144- if self .autotune_level not in [AutotuneLevel .AUTOTUNE_OVERWRITE , AutotuneLevel .AUTOTUNE_RUNTIME_OVERWRITE ]:
153+ if get_triton_autotune_level () not in [
154+ AutotuneLevel .AUTOTUNE_OVERWRITE ,
155+ AutotuneLevel .AUTOTUNE_RUNTIME_OVERWRITE ,
156+ ]:
145157 self ._try_load_cache (static_key )
146158
147- if static_key not in self .cached_configs and self . autotune_level == AutotuneLevel .NO_AUTOTUNE :
159+ if static_key not in self .cached_configs and get_triton_autotune_level () == AutotuneLevel .NO_AUTOTUNE :
148160 if (dist .is_initialized () and get_current_rank_in_node () == 0 ) or not dist .is_initialized ():
149161 logger .warning (
150162 f"No kernel config for { self .kernel_name } in { KernelConfigs .get_config_file_name (static_key )} " ,
151163 )
152164 self .cached_configs [static_key ] = {}
153165
154- if self . autotune_level != AutotuneLevel .NO_AUTOTUNE :
166+ if get_triton_autotune_level () != AutotuneLevel .NO_AUTOTUNE :
155167 need_tunning = run_key not in self .cached_configs .get (static_key , {})
156168 if world_size > 1 :
157169 _need_tunnings = [None for _ in range (world_size )]
158- dist .all_gather_object (
159- _need_tunnings , obj = need_tunning , group = _get_autotune_group ()
160- )
170+ dist .all_gather_object (_need_tunnings , obj = need_tunning , group = _get_autotune_group ())
161171 need_tunning = any (_need_tunnings )
162172 if need_tunning :
163173 self ._autotune (
@@ -239,9 +249,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
239249 if world_size > 1 :
240250 all_keys = [None for _ in range (world_size )]
241251 all_key_str = f"{ run_key } _{ static_key } "
242- dist .all_gather_object (
243- all_keys , obj = all_key_str , group = _get_autotune_group ()
244- )
252+ dist .all_gather_object (all_keys , obj = all_key_str , group = _get_autotune_group ())
245253 is_key_all_same = all (all_keys [0 ] == k for k in all_keys )
246254 if not is_key_all_same :
247255 logger .warning (
@@ -394,4 +402,4 @@ def split_configs(configs, global_rank, global_world_size):
394402
395403
def closest_pow_of_2(x):
    """Round ``x`` to a nearby power of two (heuristic).

    Subtracts a quarter of ``x``'s next power of two before rounding up
    again, so values only slightly above a power of two snap back down to
    it rather than jumping to the next one (e.g. 5 -> 4, 7 -> 8).
    """
    upper_bound = triton.next_power_of_two(x)
    return triton.next_power_of_two(x - upper_bound // 4)
0 commit comments