@@ -77,6 +77,22 @@ def decorator(fn):
7777
7878
7979class Autotuner :
80+ _autotune_warmup : bool = False
81+
82+ @staticmethod
83+ def start_autotune_warmup ():
84+ Autotuner ._autotune_warmup = True
85+ return
86+
87+ @staticmethod
88+ def end_autotune_warmup ():
89+ Autotuner ._autotune_warmup = False
90+ return
91+
92+ @staticmethod
93+ def is_autotune_warmup ():
94+ return Autotuner ._autotune_warmup
95+
8096 def __init__ (
8197 self ,
8298 fn ,
@@ -104,6 +120,7 @@ def __init__(
104120 self .run_key_distance_func = run_key_distance_func
105121 self .cached_configs = {}
106122 self .fast_match_configs = collections .defaultdict (dict )
123+ self .warmuped_configs_set = set ()
107124 self .arg_names = [param .name for param in inspect .signature (self .fn ).parameters .values ()]
108125 self ._argname_to_pos = {name : idx for idx , name in enumerate (self .arg_names )}
109126 self ._pos_to_argname = {idx : name for idx , name in enumerate (self .arg_names )}
@@ -139,7 +156,13 @@ def __call__(self, *args, **kwargs):
139156 run_key = str (self ._run_key (* args , ** kwargs ))
140157
141158 # Lazy load the cached configs in lightllm/common/triton_utils/autotune_kernel_configs
142- self ._try_load_cache (static_key )
159+ if self ._try_load_cache (static_key ) or Autotuner .is_autotune_warmup ():
160+ all_configs = self .cached_configs .get (static_key , {})
161+ for run_config in all_configs .values ():
162+ # warmup all configs
163+ _copy_kwargs = kwargs .copy ()
164+ _copy_kwargs ["run_config" ] = run_config
165+ self .kernel_warmup (static_key , * args , ** _copy_kwargs )
143166
144167 if static_key not in self .cached_configs and autotune_level == AutotuneLevel .USE_AUTOTUNE_HIS_CONFIG :
145168 if (dist .is_initialized () and get_current_rank_in_node () == 0 ) or not dist .is_initialized ():
@@ -150,7 +173,10 @@ def __call__(self, *args, **kwargs):
150173 )
151174 self .cached_configs [static_key ] = {}
152175
153- if autotune_level in [AutotuneLevel .ADAPTIVE_AUTOTUNE , AutotuneLevel .FORCE_AUTOTUNE ]:
176+ if (
177+ autotune_level in [AutotuneLevel .ADAPTIVE_AUTOTUNE , AutotuneLevel .FORCE_AUTOTUNE ]
178+ and Autotuner .is_autotune_warmup ()
179+ ):
154180 need_tuning = (autotune_level == AutotuneLevel .FORCE_AUTOTUNE ) or (
155181 run_key not in self .cached_configs .get (static_key , {})
156182 )
@@ -185,13 +211,28 @@ def __call__(self, *args, **kwargs):
185211
186212 def _try_load_cache (self , static_key ):
187213 if static_key in self .cached_configs :
188- return
214+ return False
189215
190216 cache_file = os .path .join (self .cache_dir , KernelConfigs .get_config_file_name (static_key ))
191217 if os .path .exists (cache_file ):
192218 logger .info (f"Loading cached configs for { self .kernel_name } - { static_key } " )
193219 with open (cache_file , "rb" ) as f :
194220 self .cached_configs [static_key ] = orjson .loads (f .read ())
221+ return True
222+
223+ def kernel_warmup (self , static_key , * args , ** kwargs ):
224+ new_args , new_kwargs , origin_list , new_list = self ._mutate_args_clone (args , kwargs )
225+ run_config = kwargs .get ("run_config" , {})
226+ hash_key = str (frozendict (run_config )) + str (static_key )
227+ if hash_key in self .warmuped_configs_set :
228+ return
229+ try :
230+ self .fn (* new_args , ** new_kwargs )
231+ self .warmuped_configs_set .add (hash_key )
232+ except :
233+ pass
234+ finally :
235+ self ._recover_mutated_args (origin_list = origin_list , new_list = new_list )
195236 return
196237
197238 def _bench (self , * args , n_repeat = 3 , n_retries = 3 , ** kwargs ):
0 commit comments