Commit f401a6d

fix
1 parent eae44a8 commit f401a6d

1 file changed: +16 −27 lines


lightllm/common/triton_utils/autotuner.py

Lines changed: 16 additions & 27 deletions
```diff
@@ -3,8 +3,6 @@
 import os
 import inspect
 import torch
-import fcntl
-import random
 import torch.distributed as dist
 from pathlib import Path
 from tqdm import tqdm
@@ -14,7 +12,7 @@
 from typing import Callable, Optional, Union, List
 from lightllm.utils.envs_utils import is_triton_autotune_enabled
 from lightllm.common.kernel_config import KernelConfigs
-from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size, get_global_rank
+from lightllm.utils.dist_utils import get_global_world_size, get_global_rank

 logger = init_logger(__name__)

```
```diff
@@ -143,7 +141,7 @@ def kernel_call():
         kernel_call()
         torch.cuda.synchronize()

-        state = BenchmarkState()
+        state = _BenchmarkState()
         for i in range(n_retries):
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
```
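The retry loop above uses the standard CUDA-event timing pattern rather than wall-clock time. As a minimal sketch of that pattern, with `kernel_call` standing in for the kernel under test and the warmup count an assumption:

```python
import torch

def time_kernel_ms(kernel_call, n_warmup: int = 3) -> float:
    """Time one call of kernel_call in milliseconds using CUDA events."""
    for _ in range(n_warmup):
        kernel_call()          # warm up: triggers compilation, fills caches
    torch.cuda.synchronize()   # drain pending GPU work before timing

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    kernel_call()
    end_event.record()
    torch.cuda.synchronize()   # wait until end_event is reached on the GPU
    return start_event.elapsed_time(end_event)  # elapsed milliseconds
```

Because the events are recorded on the CUDA stream, the elapsed time measures GPU execution rather than Python-side launch overhead.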
```diff
@@ -158,15 +156,14 @@ def kernel_call():
         return float("inf")

     def _autotune(self, args, kwargs, static_key, run_key):
-        if self.configs is None:
-            self.configs = split_configs(self.configs_gen_func())
+        rank_tuning_configs = split_configs(self.configs_gen_func())

         rank_id = get_global_rank()
         _best_config = None
         best_time = float("inf")

         bar = tqdm(
-            self.configs,
+            rank_tuning_configs,
             desc=f"Autotuning {self.kernel_name} for {run_key}",
             position=get_global_rank(),
             dynamic_ncols=True,
```
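After this change, each call regenerates the config list and keeps only this rank's shard (`rank_tuning_configs`) instead of caching the full list on `self`, so the ranks search disjoint parts of the space in parallel. The diff does not show how per-rank winners are combined; one hypothetical way to merge them, using the real `torch.distributed.all_gather_object` API with assumed variable names:

```python
import torch.distributed as dist

def merge_rank_bests(best_config, best_time: float):
    """Gather (time, config) pairs from all ranks; return the global winner."""
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, (best_time, best_config))
    return min(gathered, key=lambda pair: pair[0])[1]
```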
```diff
@@ -202,31 +199,25 @@ def _autotune(self, args, kwargs, static_key, run_key):
         if not dist.is_initialized() or get_global_rank() == 0:
             cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
             with open(cache_file, "wb") as f:
-                fcntl.flock(f, fcntl.LOCK_EX)
-                try:
-                    f.write(
-                        orjson.dumps(
-                            self.cached_configs[static_key],
-                            option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
-                        )
+                f.write(
+                    orjson.dumps(
+                        self.cached_configs[static_key],
+                        option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
                     )
-                finally:
-                    fcntl.flock(f, fcntl.LOCK_UN)
-            logger.info(f"Saved configs for {self.name} - {static_key} - {run_key}")
-
-        kwargs["run_config"] = self.cached_configs[static_key][run_key]
+                )
+            logger.info(f"Saved configs for {self.kernel_name} - {static_key} - {run_key}")

     def _mutate_args_clone(self, args, kwargs):
         new_kwargs = kwargs.copy()
         new_args = list(args).copy()

         for name in self.mutates_args:
             if name in kwargs:
-                new_kwargs[name] = kwargs[name].clone()
+                new_kwargs[name] = None if kwargs[name] is None else kwargs[name].clone()
             else:
                 pos = self._argname_to_pos.get(name, None)
                 if pos is not None and pos < len(args):
-                    new_args[pos] = args[pos].clone()
+                    new_args[pos] = None if args[pos] is None else args[pos].clone()
                 else:
                     raise KeyError(f"Missing argument '{name}' required to be mutated")
         return tuple(new_args), new_kwargs
```
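Two fixes land in this hunk. First, since the cache file is only written when `not dist.is_initialized() or get_global_rank() == 0`, at most one process ever touches it, so the `fcntl` advisory lock was redundant and is dropped. Second, the two `clone()` changes fix the same bug: a mutated argument may legitimately be `None` (an optional tensor), and calling `.clone()` on `None` raises `AttributeError`. The guard in isolation, as a sketch with assumed names:

```python
from typing import Optional
import torch

def clone_or_none(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Clone a tensor so benchmarking can't corrupt caller state,
    but pass optional (None) arguments through unchanged."""
    return None if t is None else t.clone()
```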
```diff
@@ -256,7 +247,7 @@ def _run_key(self, *args, **kwargs):
         return self.run_key_func(*params)


-class BenchmarkState:
+class _BenchmarkState:
     def __init__(self):
         self.sum = 0
         self.min = float("inf")
```
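Renaming `BenchmarkState` to `_BenchmarkState` marks it as module-private. Only the `sum` and `min` fields appear in the diff; a minimal sketch of what such an accumulator can look like, with the `count`, `update`, and `avg` members being assumptions:

```python
class _BenchmarkState:
    """Running statistics over repeated kernel timings (sketch)."""

    def __init__(self):
        self.sum = 0.0              # total elapsed time, ms
        self.min = float("inf")     # best single run
        self.count = 0              # assumed: number of samples recorded

    def update(self, elapsed_ms: float) -> None:  # assumed helper
        self.sum += elapsed_ms
        self.min = min(self.min, elapsed_ms)
        self.count += 1

    @property
    def avg(self) -> float:  # assumed helper
        return self.sum / self.count if self.count else float("inf")
```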
```diff
@@ -275,8 +266,6 @@ def get_triton_version():


 def split_configs(configs):
-
-    random.shuffle(configs)
-    rank_in_node = get_current_rank_in_node()
-    node_world_size = get_node_world_size()
-    return configs[rank_in_node::node_world_size]
+    global_rank = get_global_rank()
+    global_world_size = get_global_world_size()
+    return configs[global_rank::global_world_size]
```
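The new `split_configs` replaces the per-node `random.shuffle` with a deterministic global stride: rank `r` of `W` takes every `W`-th config starting at index `r`, so the shards are disjoint and together cover the whole list. Shuffling independently on each rank, by contrast, could only guarantee that property if every rank shared the same RNG state. A quick illustration with made-up values:

```python
configs = list(range(10))   # stand-in for the generated kernel configs
world_size = 4

shards = [configs[rank::world_size] for rank in range(world_size)]
assert shards == [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
assert sorted(c for shard in shards for c in shard) == configs  # full coverage
```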
