Skip to content

Commit 8057993

Browse files
authored
fix
1 parent f5a27fd commit 8057993

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

lightllm/common/triton_utils/autotuner.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,14 @@
1919
logger = init_logger(__name__)
2020

2121

22-
def _get_autotune_group():
23-
from lightllm.distributed.communication_op import dist_group_manager
24-
25-
return dist_group_manager.get_default_group().autotune_group
26-
27-
2822
class AutotuneLevel:
29-
# Do not autotune, only use the config of cached files.
30-
NO_AUTOTUNE = 0
23+
USE_AUTOTUNE_HIS_CONFIG = 0
3124
# Autotune if no config is cached.
32-
AUTOTUNE = 1
25+
ADAPTIVE_AUTOTUNE = 0
3326
# Autotune anyway to overwrite the config of cached files.
34-
AUTOTUNE_OVERWRITE = 2
27+
FORCE_AUTOTUNE = 1
3528
# Close autotune and not use the config of cached files.
36-
CLOSE_AUTOTUNE = 3
29+
CLOSE_AUTOTUNE = 2
3730

3831

3932
def autotune(
@@ -104,6 +97,7 @@ def __init__(
10497
get_current_device_name(),
10598
self.kernel_name,
10699
)
100+
os.makedirs(self.cache_dir, exist_ok=True)
107101
self.fn = fn
108102
self.static_key_func = static_key_func
109103
self.run_key_func = run_key_func
@@ -163,7 +157,7 @@ def __call__(self, *args, **kwargs):
163157
)
164158
if world_size > 1:
165159
_need_tunings = [None for _ in range(world_size)]
166-
dist.all_gather_object(_need_tunings, obj=need_tuning, group=_get_autotune_group())
160+
dist.all_gather_object(_need_tunings, obj=need_tuning, group=self._get_autotune_group())
167161
need_tuning = any(_need_tunings)
168162
if need_tuning:
169163
self._autotune(
@@ -245,7 +239,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
245239
if world_size > 1:
246240
all_keys = [None for _ in range(world_size)]
247241
all_key_str = f"{run_key}_{static_key}"
248-
dist.all_gather_object(all_keys, obj=all_key_str, group=_get_autotune_group())
242+
dist.all_gather_object(all_keys, obj=all_key_str, group=self._get_autotune_group())
249243
is_key_all_same = all(all_keys[0] == k for k in all_keys)
250244
if not is_key_all_same:
251245
logger.warning(
@@ -286,7 +280,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
286280
dist.all_gather_object(
287281
all_gather_configs,
288282
obj=(best_time, run_key, dict(static_key), best_config),
289-
group=_get_autotune_group(),
283+
group=self._get_autotune_group(),
290284
)
291285
all_gather_configs = sorted(all_gather_configs, key=lambda x: x[0])
292286
key_set = set()
@@ -312,7 +306,6 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
312306
if rank_id == 0:
313307
for _static_key in update_static_key_list:
314308
cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(_static_key))
315-
os.makedirs(self.cache_dir, exist_ok=True)
316309
with open(cache_file, "wb") as f:
317310
f.write(
318311
orjson.dumps(
@@ -374,6 +367,12 @@ def _run_key(self, *args, **kwargs):
374367
params = self._select_args(self._run_key_func_param_names, args, kwargs)
375368
return self.run_key_func(*params)
376369

370+
def _get_autotune_group(self,):
371+
from lightllm.distributed.communication_op import dist_group_manager
372+
373+
return dist_group_manager.get_default_group().autotune_group
374+
375+
377376

378377
class _BenchmarkState:
379378
def __init__(self):
@@ -397,6 +396,3 @@ def split_configs(configs, global_rank, global_world_size):
397396
random.Random(0).shuffle(configs)
398397
return configs[global_rank::global_world_size]
399398

400-
401-
def closest_pow_of_2(x):
402-
return triton.next_power_of_2(x - triton.next_power_of_2(x) // 4)

0 commit comments

Comments
 (0)