 import os
 import inspect
 import torch
-import fcntl
-import random
 import torch.distributed as dist
 from pathlib import Path
 from tqdm import tqdm
@@ -14,7 +12,7 @@
 from typing import Callable, Optional, Union, List
 from lightllm.utils.envs_utils import is_triton_autotune_enabled
 from lightllm.common.kernel_config import KernelConfigs
-from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size, get_global_rank
+from lightllm.utils.dist_utils import get_global_world_size, get_global_rank
 
 logger = init_logger(__name__)
 
@@ -143,7 +141,7 @@ def kernel_call():
         kernel_call()
         torch.cuda.synchronize()
 
-        state = BenchmarkState()
+        state = _BenchmarkState()
         for i in range(n_retries):
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
@@ -158,15 +156,14 @@ def kernel_call():
         return float("inf")
 
     def _autotune(self, args, kwargs, static_key, run_key):
-        if self.configs is None:
-            self.configs = split_configs(self.configs_gen_func())
+        rank_tuning_configs = split_configs(self.configs_gen_func())
 
         rank_id = get_global_rank()
         _best_config = None
         best_time = float("inf")
 
         bar = tqdm(
-            self.configs,
+            rank_tuning_configs,
             desc=f"Autotuning {self.kernel_name} for {run_key}",
             position=get_global_rank(),
             dynamic_ncols=True,
@@ -202,31 +199,25 @@ def _autotune(self, args, kwargs, static_key, run_key):
         if not dist.is_initialized() or get_global_rank() == 0:
             cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
             with open(cache_file, "wb") as f:
-                fcntl.flock(f, fcntl.LOCK_EX)
-                try:
-                    f.write(
-                        orjson.dumps(
-                            self.cached_configs[static_key],
-                            option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
-                        )
+                f.write(
+                    orjson.dumps(
+                        self.cached_configs[static_key],
+                        option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
                     )
-                finally:
-                    fcntl.flock(f, fcntl.LOCK_UN)
-            logger.info(f"Saved configs for {self.name} - {static_key} - {run_key}")
-
-        kwargs["run_config"] = self.cached_configs[static_key][run_key]
+                )
+            logger.info(f"Saved configs for {self.kernel_name} - {static_key} - {run_key}")
 
     def _mutate_args_clone(self, args, kwargs):
         new_kwargs = kwargs.copy()
         new_args = list(args).copy()
 
         for name in self.mutates_args:
             if name in kwargs:
-                new_kwargs[name] = kwargs[name].clone()
+                new_kwargs[name] = None if kwargs[name] is None else kwargs[name].clone()
             else:
                 pos = self._argname_to_pos.get(name, None)
                 if pos is not None and pos < len(args):
-                    new_args[pos] = args[pos].clone()
+                    new_args[pos] = None if args[pos] is None else args[pos].clone()
                 else:
                     raise KeyError(f"Missing argument '{name}' required to be mutated")
         return tuple(new_args), new_kwargs
@@ -256,7 +247,7 @@ def _run_key(self, *args, **kwargs):
         return self.run_key_func(*params)
 
 
-class BenchmarkState:
+class _BenchmarkState:
     def __init__(self):
         self.sum = 0
         self.min = float("inf")
@@ -275,8 +266,6 @@ def get_triton_version():
 
 
 def split_configs(configs):
-
-    random.shuffle(configs)
-    rank_in_node = get_current_rank_in_node()
-    node_world_size = get_node_world_size()
-    return configs[rank_in_node::node_world_size]
+    global_rank = get_global_rank()
+    global_world_size = get_global_world_size()
+    return configs[global_rank::global_world_size]