2323
def autotune(
    name: str,
    configs_gen_func: Callable[[], List],
    static_key_func: "Optional[Callable]" = None,
    run_key_func: "Optional[Callable]" = None,
    run_key_distance_func: "Optional[Callable]" = None,
    mutates_args: Optional[List[str]] = None,
):
    """Build a decorator that wraps a kernel function in an ``Autotuner``.

    Args:
        name: Identifier of the tuned kernel; handed to ``Autotuner`` as its
            ``name``.
        configs_gen_func: Zero-argument callable returning the list of
            candidate configs. It is forwarded uncalled, so the candidate
            list is not built at decoration time.
        static_key_func: Optional callable forwarded to ``Autotuner``.
        run_key_func: Optional callable forwarded to ``Autotuner``.
        run_key_distance_func: Optional callable forwarded to ``Autotuner``.
        mutates_args: Optional list of argument names forwarded to
            ``Autotuner``; ``None`` means no names are given.
            (presumably the arguments the kernel writes to in place --
            confirm against ``Autotuner``.)

    Returns:
        A decorator taking the function to tune and returning the
        ``Autotuner`` instance that wraps it.
    """

    def decorator(fn):
        # Parameter names in declaration order; iterating the mapping
        # returned by ``inspect.signature(fn).parameters`` yields the names.
        arg_names = list(inspect.signature(fn).parameters)
        # NOTE(review): ``run_key_distance_func`` is forwarded even when it is
        # None, so any default declared on ``Autotuner.__init__`` itself is
        # not used for tuners created through this decorator -- confirm that
        # ``Autotuner`` handles None.
        return Autotuner(
            fn,
            arg_names,
            name,
            configs_gen_func,
            static_key_func=static_key_func,
            run_key_func=run_key_func,
            run_key_distance_func=run_key_distance_func,
            mutates_args=mutates_args,
        )

    return decorator
@@ -76,21 +70,18 @@ def __init__(
7670 fn ,
7771 arg_names ,
7872 name ,
79- configs : "Optional[Union[List, Callable[[], List]]]" ,
80- static_key_func : "Optional[Callable[[], dict]]" = None ,
81- run_key_func : "Optional[Callable]" = None ,
82- run_key_distance_func : "Optional[Callable]" = lambda a , b : abs (a - b ),
83- reset_to_zero : "Optional[List]" = None ,
84- restore_value : "Optional[List]" = None ,
85- pre_hook : "Optional[Callable]" = None ,
86- post_hook : "Optional[Callable]" = None ,
73+ configs_gen_func : Callable [[], List ],
74+ static_key_func : Optional [Callable [[], dict ]] = None ,
75+ run_key_func : Optional [Callable ] = None ,
76+ run_key_distance_func : Optional [Callable ] = lambda run_key , run_key_cached : abs (run_key - run_key_cached ),
77+ mutates_args : List [str ] = None ,
8778 ):
8879 # Whether to use this autotune decorator
8980 self .disable_autotune = os .environ .get ("DISABLE_AUTOTUNE_DECORATOR" , "0" ) == "1"
9081
91- self .configs : "Optional[Union[List, Callable[[], List]]]" = configs
92- if not callable ( self .configs ):
93- self . configs = split_configs ( self . configs )
82+ self .configs = None
83+ self .configs_gen_func = configs_gen_func
84+
9485 self .name = name
9586 self .cache_dir = os .path .join (
9687 Path (__file__ ).parent , "all_kernel_configs" , get_triton_version (), get_current_device_name (), self .name
@@ -106,34 +97,21 @@ def __init__(
10697 self ._static_param_names = self ._get_param_names (self .static_key_func )
10798 self ._run_param_names = self ._get_param_names (self .run_key_func )
10899
109- self .reset_to_zero = []
110- if reset_to_zero is not None :
111- self .reset_to_zero = list (reset_to_zero )
112- self .restore_value = []
113- if restore_value is not None :
114- self .restore_value = list (restore_value )
100+ self .mutates_args = []
101+ if mutates_args is not None :
102+ self .mutates_args = list (mutates_args )
115103
116104 self .pre_hook = lambda kwargs , reset_only = False : 0
117105 self .post_hook = lambda kwargs , exception : 0
118106 self .user_defined_pre_hook = False
119107 self .user_defined_post_hook = False
120- if pre_hook :
121- self .pre_hook = pre_hook
122- self .user_defined_pre_hook = True
123- elif len (self .reset_to_zero ) > 0 or len (self .restore_value ) > 0 :
124108
125- def _pre_hook (kwargs , reset_only = False ):
126- for name in self .reset_to_zero :
127- kwargs [name ].zero_ ()
128- if not reset_only :
129- self .restore_copies = {name : kwargs [name ].clone () for name in self .restore_value }
109+ if len (self .mutates_args ) > 0 :
130110
131- self .pre_hook = _pre_hook
111+ def _pre_hook (kwargs ):
112+ self .restore_copies = {name : kwargs [name ].clone () for name in self .mutates_args }
132113
133- if post_hook :
134- self .post_hook = post_hook
135- self .user_defined_post_hook = True
136- elif len (self .restore_value ) > 0 :
114+ self .pre_hook = _pre_hook
137115
138116 def _post_hook (kwargs , exception ):
139117 for name in self .restore_value :
@@ -143,7 +121,7 @@ def _post_hook(kwargs, exception):
143121 self .post_hook = _post_hook
144122
145123 if not os .path .exists (self .cache_dir ):
146- if os . environ . get ( "LIGHTLLM_TRITON_AUTOTUNE" , "0" ) == "1" :
124+ if is_triton_autotune_enabled () :
147125 os .makedirs (self .cache_dir , exist_ok = True )
148126
149127 self ._loaded_static_keys = set ()
@@ -207,11 +185,11 @@ def kernel_call():
207185 def _autotune (self , args , kwargs , static_key , run_key ):
208186 from lightllm .utils .dist_utils import get_global_rank
209187
210- if callable ( self .configs ) :
211- self .configs = split_configs (self .configs ())
188+ if self .configs is None :
189+ self .configs = split_configs (self .configs_gen_func ())
212190
213191 rank_id = get_global_rank ()
214- _best_config = self . default_config
192+ _best_config = None
215193 best_time = float ("inf" )
216194
217195 bar = tqdm (
@@ -261,8 +239,6 @@ def _autotune(self, args, kwargs, static_key, run_key):
261239 logger .info (f"Saved configs for { self .name } - { static_key } - { run_key } " )
262240
263241 kwargs ["run_config" ] = self .cached_configs [static_key ][run_key ]
264- full_nargs = {** self .nargs , ** kwargs }
265- self .pre_hook (full_nargs , reset_only = True )
266242
267243 @torch .no_grad ()
268244 def __call__ (self , * args , ** kwargs ):
0 commit comments