Skip to content

Commit 0184a83

Browse files
author
sufubao
committed
reformatted
1 parent e34e937 commit 0184a83

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

lightllm/common/triton_utils/autotuner.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@
2020

2121

2222
class AutotuneLevel:
23+
# Use the config of cached files in /lightllm/common/triton_utils/autotune_kernel_configs.
2324
USE_AUTOTUNE_HIS_CONFIG = 0
2425
# Autotune if no config is cached.
2526
ADAPTIVE_AUTOTUNE = 1
2627
# Autotune anyway to overwrite the config of cached files.
2728
FORCE_AUTOTUNE = 2
28-
# Close autotune and not use the config of cached files.
29+
# Close autotune and use the configs of cached files in lightllm/common/all_kernel_configs.
2930
CLOSE_AUTOTUNE = 3
3031

3132

@@ -148,7 +149,9 @@ def __call__(self, *args, **kwargs):
148149
self.cached_configs[static_key] = {}
149150

150151
if autotune_level in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
151-
need_tuning = (autotune_level == AutotuneLevel.FORCE_AUTOTUNE) or (run_key not in self.cached_configs.get(static_key, {}))
152+
need_tuning = (autotune_level == AutotuneLevel.FORCE_AUTOTUNE) or (
153+
run_key not in self.cached_configs.get(static_key, {})
154+
)
152155
if world_size > 1:
153156
_need_tunings = [None for _ in range(world_size)]
154157
dist.all_gather_object(_need_tunings, obj=need_tuning, group=self._get_autotune_group())
@@ -162,7 +165,7 @@ def __call__(self, *args, **kwargs):
162165
rank_id=rank_id,
163166
world_size=world_size,
164167
)
165-
168+
166169
closest_config = self.fast_match_configs.get(static_key, {}).get(run_key, None)
167170
if closest_config is not None:
168171
kwargs["run_config"] = closest_config
@@ -361,13 +364,14 @@ def _run_key(self, *args, **kwargs):
361364
params = self._select_args(self._run_key_func_param_names, args, kwargs)
362365
return self.run_key_func(*params)
363366

364-
def _get_autotune_group(self,):
367+
def _get_autotune_group(
368+
self,
369+
):
365370
from lightllm.distributed.communication_op import dist_group_manager
366371

367372
return dist_group_manager.get_default_group().autotune_group
368373

369374

370-
371375
class _BenchmarkState:
372376
def __init__(self):
373377
self.sum = 0
@@ -389,4 +393,3 @@ def get_triton_version():
389393
def split_configs(configs, global_rank, global_world_size):
390394
random.Random(0).shuffle(configs)
391395
return configs[global_rank::global_world_size]
392-

0 commit comments

Comments
 (0)