
Commit 00885fb

committed: fix
1 parent d6695ef · commit 00885fb

File tree: 1 file changed (+21, -13)


lightllm/common/triton_utils/autotuner.py

Lines changed: 21 additions & 13 deletions
@@ -21,6 +21,7 @@
 
 def _get_autotune_group():
     from lightllm.distributed.communication_op import dist_group_manager
+
     return dist_group_manager.get_default_group().autotune_group
 
 
@@ -70,6 +71,7 @@ def autotune(
         Callable: A callable object that wraps the original function and performs autotuning
         as needed before invocation.
     """
+
     def decorator(fn):
         return Autotuner(
             fn=fn,
@@ -121,8 +123,15 @@ def __init__(
         ]
         self._run_key_func_param_names = [name for name, _ in inspect.signature(self.run_key_func).parameters.items()]
         self.mutates_args = mutates_args
-
-        assert self.autotune_level in [AutotuneLevel.NO_AUTOTUNE, AutotuneLevel.AUTOTUNE, AutotuneLevel.AUTOTUNE_OVERWRITE, AutotuneLevel.AUTOTUNE_RUNTIME, AutotuneLevel.AUTOTUNE_RUNTIME_OVERWRITE, AutotuneLevel.CLOSE_AUTOTUNE]
+
+        assert self.autotune_level in [
+            AutotuneLevel.NO_AUTOTUNE,
+            AutotuneLevel.AUTOTUNE,
+            AutotuneLevel.AUTOTUNE_OVERWRITE,
+            AutotuneLevel.AUTOTUNE_RUNTIME,
+            AutotuneLevel.AUTOTUNE_RUNTIME_OVERWRITE,
+            AutotuneLevel.CLOSE_AUTOTUNE,
+        ]
         return
 
     @torch.no_grad()
@@ -131,7 +140,7 @@ def __call__(self, *args, **kwargs):
             return self.fn(*args, **kwargs)
 
         # if the autotune_level is AutotuneLevel.CLOSE_AUTOTUNE, ignore the autotune
-        if self.autotune_level == AutotuneLevel.CLOSE_AUTOTUNE:
+        if get_triton_autotune_level() == AutotuneLevel.CLOSE_AUTOTUNE:
             return self.fn(*args, **kwargs)
 
         rank_id = 0 if not dist.is_initialized() else get_global_rank()
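
The substantive change in this hunk (and the two below) is swapping the level captured on the instance, `self.autotune_level`, for a call-time lookup via `get_triton_autotune_level()`. A plausible reading of the fix is that the autotune level can change after the decorator object is constructed (decorators usually run at import time), so the dispatch checks should consult the current value rather than a copy cached in `__init__`. The sketch below illustrates only that general pattern; `_LEVEL`, `get_level`, `CachedAtInit`, and `ReadAtCall` are hypothetical names for the demo, not lightllm APIs.

```python
_LEVEL = "AUTOTUNE"  # stand-in for a process-wide autotune level setting


def get_level():
    return _LEVEL


class CachedAtInit:
    """Reads the level once, when the decorator is applied (pre-fix pattern)."""

    def __init__(self, fn):
        self.fn = fn
        self.level = get_level()  # frozen at decoration (usually import) time

    def __call__(self, *args, **kwargs):
        if self.level == "CLOSE_AUTOTUNE":
            return self.fn(*args, **kwargs)  # bypass tuning
        return ("tuned", self.fn(*args, **kwargs))


class ReadAtCall:
    """Re-reads the level on every call (the pattern used after this commit)."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        if get_level() == "CLOSE_AUTOTUNE":
            return self.fn(*args, **kwargs)  # bypass tuning
        return ("tuned", self.fn(*args, **kwargs))


@CachedAtInit
def f(x):
    return x + 1


@ReadAtCall
def g(x):
    return x + 1


_LEVEL = "CLOSE_AUTOTUNE"  # the setting changes after the decorators ran
print(f(1))  # ('tuned', 2) -- the stale cached level still enables tuning
print(g(1))  # 2            -- the call-time lookup honours the new setting
```
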
@@ -141,23 +150,24 @@ def __call__(self, *args, **kwargs):
         run_key = str(self._run_key(*args, **kwargs))
 
         # Lazy load the cached configs in lightllm/common/triton_utils/autotune_kernel_configs
-        if self.autotune_level not in [AutotuneLevel.AUTOTUNE_OVERWRITE, AutotuneLevel.AUTOTUNE_RUNTIME_OVERWRITE]:
+        if get_triton_autotune_level() not in [
+            AutotuneLevel.AUTOTUNE_OVERWRITE,
+            AutotuneLevel.AUTOTUNE_RUNTIME_OVERWRITE,
+        ]:
             self._try_load_cache(static_key)
 
-        if static_key not in self.cached_configs and self.autotune_level == AutotuneLevel.NO_AUTOTUNE:
+        if static_key not in self.cached_configs and get_triton_autotune_level() == AutotuneLevel.NO_AUTOTUNE:
             if (dist.is_initialized() and get_current_rank_in_node() == 0) or not dist.is_initialized():
                 logger.warning(
                     f"No kernel config for {self.kernel_name} in {KernelConfigs.get_config_file_name(static_key)}",
                 )
             self.cached_configs[static_key] = {}
 
-        if self.autotune_level != AutotuneLevel.NO_AUTOTUNE:
+        if get_triton_autotune_level() != AutotuneLevel.NO_AUTOTUNE:
             need_tunning = run_key not in self.cached_configs.get(static_key, {})
             if world_size > 1:
                 _need_tunnings = [None for _ in range(world_size)]
-                dist.all_gather_object(
-                    _need_tunnings, obj=need_tunning, group=_get_autotune_group()
-                )
+                dist.all_gather_object(_need_tunnings, obj=need_tunning, group=_get_autotune_group())
                 need_tunning = any(_need_tunnings)
             if need_tunning:
                 self._autotune(
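
The reflowed `dist.all_gather_object(...)` call is how the ranks reach a common decision: each rank contributes its local `need_tunning` flag, and the gathered list is reduced with `any(...)` so that every rank takes the same tuning branch. Below is a minimal, runnable sketch of that gather-then-any pattern using the public `torch.distributed` API with the CPU `gloo` backend; the port, world size, and the rule for which rank "needs tuning" are arbitrary demo choices, not taken from lightllm.

```python
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # CPU-only gloo backend; address/port are arbitrary demo values.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29531",
        rank=rank,
        world_size=world_size,
    )

    # Pretend only rank 1 is missing a cached config for the current run_key.
    need_tuning = rank == 1

    # Same pattern as in the diff: gather every rank's flag, then reduce with any().
    flags = [None for _ in range(world_size)]
    dist.all_gather_object(flags, obj=need_tuning)
    need_tuning = any(flags)

    # All ranks now agree on the result and take the same branch.
    print(f"rank {rank}: gathered={flags} -> need_tuning={need_tuning}")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
```
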
@@ -239,9 +249,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
         if world_size > 1:
             all_keys = [None for _ in range(world_size)]
             all_key_str = f"{run_key}_{static_key}"
-            dist.all_gather_object(
-                all_keys, obj=all_key_str, group=_get_autotune_group()
-            )
+            dist.all_gather_object(all_keys, obj=all_key_str, group=_get_autotune_group())
             is_key_all_same = all(all_keys[0] == k for k in all_keys)
             if not is_key_all_same:
                 logger.warning(
@@ -394,4 +402,4 @@ def split_configs(configs, global_rank, global_world_size):
 
 
 def closest_pow_of_2(x):
-    return triton.next_power_of_two(x - triton.next_power_of_two(x)//4)
+    return triton.next_power_of_two(x - triton.next_power_of_two(x) // 4)
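
The only change to `closest_pow_of_2` is whitespace around `//`, but the arithmetic is worth spelling out: with P the next power of two at or above x (so P/2 < x <= P), `next_power_of_two(x - P // 4)` rounds x down to P // 2 when x <= 0.75 * P and up to P otherwise, i.e. it snaps x to the nearer power of two with the cut-off at three quarters of the upper one. A pure-Python sketch, assuming `triton.next_power_of_two` behaves like the usual smallest-power-of-two-not-below-n helper; the local `next_power_of_two` below is that stand-in, not the Triton function.

```python
def next_power_of_two(n: int) -> int:
    # Smallest power of two >= n, for n >= 1 (stand-in for triton.next_power_of_two).
    return 1 << (n - 1).bit_length()


def closest_pow_of_2(x: int) -> int:
    # Same expression as in the diff, with the Triton helper swapped for the stand-in.
    return next_power_of_two(x - next_power_of_two(x) // 4)


if __name__ == "__main__":
    for x in (5, 6, 7, 12, 13, 48, 49):
        print(x, "->", closest_pow_of_2(x))
    # 5 -> 4, 6 -> 4, 7 -> 8, 12 -> 8, 13 -> 16, 48 -> 32, 49 -> 64
```
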
