
Commit a6b2f08

fix
1 parent 3328341 commit a6b2f08

6 files changed: +39 −49 lines


lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 0 additions & 8 deletions
@@ -489,14 +489,6 @@ def get_grouped_matmul_static_key(
         for bn in [16, 32, 64, 128]
         for bk in [16, 32, 64, 128]
     ],
-    default_config={
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-        "num_warps": 4,
-        "num_stages": 1,
-    },
     static_key_func=get_grouped_matmul_static_key,
     run_key_func=lambda token_num_mul_topk_num: str(nearest_power_of_2(token_num_mul_topk_num)),
 )
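The same pattern repeats across every call site in this commit: default_config is dropped, so a kernel declares only its config search space and its key functions, and config selection is left entirely to the cache logic in autotuner.py. A minimal sketch of the resulting call-site shape, with a hypothetical kernel and block-size grid (names are illustrative, not from this repo):

@autotune(
    name="my_kernel:v1",  # hypothetical kernel name, for illustration only
    configs=[
        {"BLOCK_M": bm, "BLOCK_N": bn, "num_warps": 4, "num_stages": 1}
        for bm in [32, 64, 128]
        for bn in [32, 64, 128]
    ],
    static_key_func=lambda x: f"N={x.shape[-1]}",  # shape/dtype bucket, selects the cache file
    run_key_func=lambda x: str(nearest_power_of_2(x.shape[0])),  # per-call bucket within that file
)
def my_kernel(x, run_config=None):
    # The decorator injects the selected config through the run_config kwarg.
    ...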

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 0 additions & 1 deletion
@@ -72,7 +72,6 @@ def _silu_and_mul_kernel_fast(
         for bm in [32, 64, 128, 256]
         for bn in [32, 64, 128, 256]
     ],
-    default_config={"BLOCK_M": 128, "BLOCK_N": 128, "num_warps": 4, "NUM_STAGES": 1},
     static_key_func=lambda input, output: f"N={input.shape[-1] // 2},out_dtype={output.dtype}",
     run_key_func=lambda input: str(nearest_power_of_2(input.shape[0])),
 )

lightllm/common/fused_moe/moe_sum_reduce.py

Lines changed: 0 additions & 1 deletion
@@ -61,7 +61,6 @@ def get_static_key(input, output):
         for bm in [1, 2, 4, 8, 16, 32]
         for bd in [64, 128, 256, 512, 1024]
     ],
-    default_config={"BLOCK_M": 1, "BLOCK_DIM": 128, "NUM_STAGE": 1, "num_warps": 2},
     static_key_func=get_static_key,
     run_key_func=lambda input: str(nearest_power_of_2(input.shape[0])),
 )

lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py

Lines changed: 0 additions & 1 deletion
@@ -179,7 +179,6 @@ def get_static_key(A, B, block_size, dtype):
 @autotune(
     name="w8a8_block_fp8_matmul:v1",
     configs=get_test_configs(),
-    default_config={"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 32, "num_stages": 3, "num_warps": 4},
     static_key_func=get_static_key,
     run_key_func=lambda M: str(nearest_power_of_2(M)),
 )

lightllm/common/triton_utils/autotuner.py

Lines changed: 39 additions & 37 deletions
@@ -19,7 +19,6 @@
 def autotune(
     name,
     configs,
-    default_config,
     static_key_func=None,
     run_key_func=None,
     reset_to_zero=None,
@@ -34,7 +33,6 @@ def decorator(fn):
         arg_names,
         name,
         configs,
-        default_config,
         static_key_func,
         run_key_func,
         reset_to_zero,
@@ -46,6 +44,16 @@ def decorator(fn):
     return decorator


+def is_triton_autotune_enabled():
+    # Whether Triton autotune is enabled (read-only check)
+    return os.environ.get("LIGHTLLM_TRITON_AUTOTUNE", "0") == "1"
+
+
+def disable_triton_autotune():
+    # Disable Triton autotune (setter)
+    os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "0"
+
+
 class Autotuner:
     @staticmethod
     def _get_param_names(func):
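The two new helpers centralize the LIGHTLLM_TRITON_AUTOTUNE check that was previously inlined in __call__. A minimal sketch of a warmup flow that drives them (the import path follows this file's location; the workload itself is assumed):

import os

from lightllm.common.triton_utils.autotuner import disable_triton_autotune, is_triton_autotune_enabled

os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "1"  # opt in: missing configs get benchmarked and cached
assert is_triton_autotune_enabled()

# ... run representative workloads so each (static_key, run_key) pair gets tuned ...

disable_triton_autotune()  # subsequent calls reuse the cached configs instead of re-benchmarking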
@@ -72,7 +80,6 @@ def __init__(
         arg_names,
         name,
         configs,
-        default_config,
         static_key_func,
         run_key_func,
         reset_to_zero,
@@ -82,11 +89,13 @@ def __init__(
         warmup=None,
         rep=None,
     ):
-        # 是否打印autotune信息
+        # Whether to print autotune logs
         self.print_autotune = os.environ.get("LIGHTLLM_TRITON_PRINT_AUTOTUNE", "0") == "1"
+        # Whether to use this autotune decorator
+        self.disable_autotune = os.environ.get("DISABLE_AUTOTUNE_DECORATOR", "0") == "1"
+
         self.all_configs = configs
         self.configs = None
-        self.default_config = default_config
         self.name = name
         self.cache_dir = os.path.join(
             Path(__file__).parent, "all_kernel_configs", get_triton_version(), get_current_device_name(), self.name
@@ -95,11 +104,6 @@ def __init__(
         self.static_key_func = static_key_func
         self.run_key_func = run_key_func

-        # 是否使用之前配置
-        self.can_be_none = os.environ.get("DISABLE_MANUAL_TUNE_CONFIG", "0") == "0"
-        # 是否使用autotune注解
-        self.disable_autotune = os.environ.get("DISABLE_AUTOTUNE_ANNOTATION", "0") == "1"
-
         self.cached_configs = {}
         self.arg_names = arg_names
         self._argname_to_pos = {name: idx for idx, name in enumerate(self.arg_names)}
@@ -165,7 +169,6 @@ def _ensure_cache_loaded(self, static_key: str):
             ]
             self.sorted_cached_configs[static_key].sort(key=lambda x: x[0])
         except Exception:
-            # 若缓存损坏,忽略并在之后覆盖
             self.cached_configs[static_key] = {}
         self._loaded_static_keys.add(static_key)

@@ -229,9 +232,8 @@ def __call__(self, *args, **kwargs):
         static_key = self._static_key(*args, **kwargs)
         run_key = self._run_key(*args, **kwargs)

-        # 懒加载
+        # Lazy load
         self._ensure_cache_loaded(static_key)
-        best_config = None
         self.nargs = dict(zip(self.arg_names, args))

         def _benchmark(_run_key):
@@ -300,30 +302,30 @@ def _benchmark(_run_key):
             full_nargs = {**self.nargs, **kwargs}
             self.pre_hook(full_nargs, reset_only=True)

-        best_config = self.cached_configs.get(static_key, {}).get(run_key)
+        if static_key not in self.cached_configs:
+            if not is_triton_autotune_enabled():
+                logger.warning(
+                    f"No kernel config for {self.name} in {self.cache_dir}/{static_key}, using default config",
+                )
+            self.cached_configs[static_key] = {}
+
+        all_configs = self.cached_configs.get(static_key)
+        best_config = all_configs.get(run_key)
+
+        if best_config is not None:
+            kwargs["run_config"] = best_config
+            return self.fn(*args, **kwargs)
+
+        if is_triton_autotune_enabled():
+            _benchmark(run_key)
+            kwargs["run_config"] = self.cached_configs.get(static_key, {}).get(run_key)
+            return self.fn(*args, **kwargs)
+
+        if all_configs != {}:
+            closest_config = min(all_configs, key=lambda x: abs(int(x[0]) - int(run_key)))[1]
+            self.cached_configs[static_key][run_key] = closest_config
+            kwargs["run_config"] = closest_config

-        if best_config is None:
-            if os.environ.get("LIGHTLLM_TRITON_AUTOTUNE", "0") == "1":
-                _benchmark(run_key)
-            elif not self.can_be_none:
-                cached_for_static = self.cached_configs.setdefault(static_key, {})
-                if static_key in self.sorted_cached_configs and self.sorted_cached_configs[static_key]:
-                    sorted_configs = self.sorted_cached_configs[static_key]
-                    try:
-                        target = int(run_key)
-                        cached_for_static[run_key] = min(sorted_configs, key=lambda x: abs(x[0] - target))[1]
-                    except Exception:
-                        cached_for_static[run_key] = self.default_config
-                else:
-                    if static_key not in self.sorted_cached_configs:
-                        logger.warning(
-                            f"No kernel config for {self.name} in {self.cache_dir}/{static_key}, using default config"
-                        )
-                    cached_for_static[run_key] = self.default_config
-
-        best_config = self.cached_configs[static_key][run_key]
-
-        kwargs["run_config"] = best_config
         return self.fn(*args, **kwargs)

     def _select_args(self, param_names, args, kwargs):
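The rewritten lookup in __call__ now has three tiers: an exact (static_key, run_key) cache hit is used directly; with autotune enabled, a missing entry is benchmarked and cached; otherwise the config tuned for the numerically closest run key is reused. A standalone sketch of the intent of that last tier, with made-up data and run keys assumed to be stringified power-of-two bucket sizes:

# Cached configs for one static key, keyed by power-of-two run keys (made-up values).
cached = {"64": {"BLOCK_M": 64, "num_warps": 4}, "512": {"BLOCK_M": 128, "num_warps": 8}}

def closest_config(cached_configs, run_key):
    # min over a dict iterates its keys; pick the numerically nearest bucket.
    nearest_key = min(cached_configs, key=lambda k: abs(int(k) - int(run_key)))
    return cached_configs[nearest_key]

print(closest_config(cached, "128"))  # {'BLOCK_M': 64, 'num_warps': 4}: 64 is nearer to 128 than 512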
@@ -412,7 +414,7 @@ def dict_to_filename(data):


 def nearest_power_of_2(x):
-    # 返回最接近 x 的 2 的幂次方
+    # Return the power of two closest to x
     if x <= 1:
         return 1
     return triton.next_power_of_2(x - triton.next_power_of_2(x) // 4)
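For reference, nearest_power_of_2 rounds to the closest power of two rather than always up: subtracting a quarter of next_power_of_2(x) before rounding up pushes values in the lower half of a bucket down to the previous power. A pure-Python mirror of the function above, runnable without Triton (next_power_of_2 here is a stand-in for triton.next_power_of_2):

def next_power_of_2(x: int) -> int:
    # Round up to the next power of two, mirroring triton.next_power_of_2.
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def nearest_power_of_2(x: int) -> int:
    if x <= 1:
        return 1
    return next_power_of_2(x - next_power_of_2(x) // 4)

print(nearest_power_of_2(70))   # 64  (a plain round-up would give 128)
print(nearest_power_of_2(100))  # 128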

lightllm/models/deepseek2/triton_kernel/rotary_emb.py

Lines changed: 0 additions & 1 deletion
@@ -121,7 +121,6 @@ def get_static_key(q, k):
 @autotune(
     name="rotary_emb_fwd:v1",
     configs=get_test_configs(),
-    default_config={"BLOCK_SEQ": 16, "NUM_STAGE": 1, "num_warps": 1, "num_stages": 1, "HEAD_PARALLEL_NUM": 1},
     static_key_func=get_static_key,
     run_key_func=lambda q: str(nearest_power_of_2(q.shape[0])),
 )
