
Commit 25b877e

fix
1 parent cc2860e commit 25b877e

7 files changed: +58 additions, -53 deletions

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 8 additions & 4 deletions
@@ -471,9 +471,8 @@ def _get_grouped_matmul_static_key(
     )
 
 
-@autotune(
-    name="grouped_matmul:v1",
-    configs=[
+def _get_grouped_matmul_configs():
+    return [
         {
             "BLOCK_SIZE_M": bm,
             "BLOCK_SIZE_N": bn,
@@ -488,7 +487,12 @@ def _get_grouped_matmul_static_key(
         for bm in [16, 32, 64, 128]
         for bn in [16, 32, 64, 128]
         for bk in [16, 32, 64, 128]
-    ],
+    ]
+
+
+@autotune(
+    name="grouped_matmul:v1",
+    configs=_get_grouped_matmul_configs,
     static_key_func=_get_grouped_matmul_static_key,
     run_key_func=lambda token_num_mul_topk_num: str(nearest_power_of_2(token_num_mul_topk_num)),
 )
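
Note: this is the core refactor repeated across the MoE kernels in this commit. The config grid moves out of the decorator call into a module-level factory, and @autotune now receives the function object instead of an eagerly built list, so the comprehension only runs if tuning is actually triggered. A minimal sketch of the pattern with stand-in names (autotune_sketch, _example_configs, grouped_matmul_stub are hypothetical; lightllm's real decorator is more involved):

def autotune_sketch(name, configs):
    # Stand-in decorator: stores `configs` as given and only materializes the
    # grid when the wrapped function is actually called/tuned.
    def decorator(fn):
        def wrapper(*args, **kwargs):
            grid = configs() if callable(configs) else configs
            _ = grid  # a real tuner would benchmark each candidate config here
            return fn(*args, **kwargs)
        return wrapper
    return decorator


def _example_configs():
    # Built lazily, not at import time.
    return [
        {"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn}
        for bm in (16, 32, 64, 128)
        for bn in (16, 32, 64, 128)
    ]


@autotune_sketch(name="grouped_matmul:v1", configs=_example_configs)
def grouped_matmul_stub(*args, **kwargs):
    pass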

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 8 additions & 4 deletions
@@ -63,15 +63,19 @@ def _silu_and_mul_kernel_fast(
     )
 
 
-@autotune(
-    name="silu_and_mul_fwd:v1",
-    configs=[
+def _get_silu_and_mul_configs():
+    return [
         {"BLOCK_M": bm, "BLOCK_N": bn, "num_warps": nw, "NUM_STAGES": ns}
         for ns in [1, 2, 4]
         for nw in [1, 4, 8]
         for bm in [32, 64, 128, 256]
         for bn in [32, 64, 128, 256]
-    ],
+    ]
+
+
+@autotune(
+    name="silu_and_mul_fwd:v1",
+    configs=_get_silu_and_mul_configs,
     static_key_func=lambda input, output: f"N={input.shape[-1] // 2},out_dtype={output.dtype}",
     run_key_func=lambda input: str(nearest_power_of_2(input.shape[0])),
 )
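
For orientation on what a chosen config does: by the usual Triton convention, lowercase num_warps is a launch option, while uppercase keys such as BLOCK_M, BLOCK_N, and NUM_STAGES are constexpr kernel parameters. A hypothetical wrapper illustrating that split (an assumption about the kernel signature, not lightllm's actual launch code):

def launch_with_config(kernel, grid, run_config, **tensor_args):
    # Hypothetical: separate the tuned config into the Triton launch option
    # `num_warps` and constexpr kernel parameters (BLOCK_M, BLOCK_N, NUM_STAGES).
    cfg = dict(run_config)
    num_warps = cfg.pop("num_warps", 4)
    return kernel[grid](**tensor_args, num_warps=num_warps, **cfg)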

lightllm/common/fused_moe/moe_sum_reduce.py

Lines changed: 8 additions & 4 deletions
@@ -52,15 +52,19 @@ def _get_static_key(input, output):
     return f"topk_num={input.shape[1]},hidden_dim={input.shape[2]},out_dtype={output.dtype}"
 
 
-@autotune(
-    name="moe_sum_reduce:v1",
-    configs=[
+def _get_moe_sum_reduce_configs():
+    return [
         {"BLOCK_M": bm, "BLOCK_DIM": bd, "NUM_STAGE": ns, "num_warps": nw}
         for ns in [1, 2, 4]
         for nw in [1, 2, 4, 8, 16]
         for bm in [1, 2, 4, 8, 16, 32]
         for bd in [64, 128, 256, 512, 1024]
-    ],
+    ]
+
+
+@autotune(
+    name="moe_sum_reduce:v1",
+    configs=_get_moe_sum_reduce_configs,
     static_key_func=_get_static_key,
     run_key_func=lambda input: str(nearest_power_of_2(input.shape[0])),
 )
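
For scale: the two comprehensions that are fully visible in these diffs expand to sizeable grids, which is part of why building them lazily in a factory is attractive. A quick arithmetic check of the candidate counts implied by the ranges above:

import math

# Candidate counts implied by the comprehension bounds shown in the diffs.
silu_and_mul = math.prod([3, 3, 4, 4])    # NUM_STAGES x num_warps x BLOCK_M x BLOCK_N
moe_sum_reduce = math.prod([3, 5, 6, 5])  # NUM_STAGE x num_warps x BLOCK_M x BLOCK_DIM
assert (silu_and_mul, moe_sum_reduce) == (144, 450)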

lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@ def _get_static_key(A, B, block_size, dtype):
 
 @autotune(
     name="w8a8_block_fp8_matmul:v1",
-    configs=get_test_configs(),
+    configs=get_test_configs,
     static_key_func=_get_static_key,
     run_key_func=lambda M: str(nearest_power_of_2(M)),
 )
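
The only change in this file is dropping the call parentheses: configs=get_test_configs() built the list at import time, while configs=get_test_configs hands the decorator a callable to invoke later. A hedged sketch of how a consumer can accept either form (resolve_configs is a hypothetical helper, not necessarily what Autotuner does internally):

from typing import Callable, Dict, List, Union

ConfigList = List[Dict[str, int]]


def resolve_configs(configs: Union[ConfigList, Callable[[], ConfigList]]) -> ConfigList:
    # Invoke a zero-argument callable lazily; pass a prebuilt list through unchanged,
    # so both `configs=get_test_configs` and `configs=get_test_configs()` keep working.
    return list(configs()) if callable(configs) else list(configs)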

lightllm/common/triton_utils/autotuner.py

Lines changed: 21 additions & 39 deletions
@@ -13,6 +13,7 @@
 from lightllm.utils.log_utils import init_logger
 import traceback
 from typing import Callable, Optional, Union, List
+from lightllm.utils.envs_utils import is_triton_autotune_enabled, disable_triton_autotune
 
 logger = init_logger(__name__)
 
@@ -45,16 +46,6 @@ def decorator(fn):
     return decorator
 
 
-def is_triton_autotune_enabled():
-    # Whether Triton autotune is enabled (read-only check)
-    return os.environ.get("LIGHTLLM_TRITON_AUTOTUNE", "0") == "1"
-
-
-def disable_triton_autotune():
-    # Disable Triton autotune (setter)
-    os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "0"
-
-
 class Autotuner:
     @staticmethod
     def _get_param_names(func):
@@ -148,8 +139,6 @@ def _post_hook(kwargs, exception):
         os.makedirs(self.cache_dir, exist_ok=True)
 
         self._loaded_static_keys = set()
-        self.sorted_cached_configs = {}
-        self.early_stop_cnt = 0
 
     @lru_cache(maxsize=None)
     def _ensure_cache_loaded(self, static_key: str):
@@ -160,15 +149,11 @@ def _ensure_cache_loaded(self, static_key: str):
         try:
             with open(cache_file, "rb") as f:
                 self.cached_configs[static_key] = orjson.loads(f.read())
-                self.sorted_cached_configs[static_key] = [
-                    (int(k), v) for k, v in self.cached_configs[static_key].items()
-                ]
-                self.sorted_cached_configs[static_key].sort(key=lambda x: x[0])
         except Exception:
             self.cached_configs[static_key] = {}
         self._loaded_static_keys.add(static_key)
 
-    def _bench(self, *args, n_repeat=5, n_retries=1, current_best_ms=None, **kwargs):
+    def _bench(self, *args, n_repeat=5, n_retries=1, **kwargs):
         from triton.compiler.errors import CompileTimeAssertionFailure
         from triton.runtime.errors import OutOfResources, PTXASError
 
@@ -234,26 +219,25 @@ def _benchmark(_run_key):
            rank_id = get_global_rank()
            _best_config = self.default_config
            best_time = float("inf")
-           self.early_stop_cnt = 0
+
            bar = tqdm(
                self.configs,
-               desc=f"Autotuning {self.name} for {_run_key}, es:{self.early_stop_cnt / len(self.configs):.2%}",
+               desc=f"Autotuning {self.name} for {_run_key}",
                position=get_global_rank(),
                dynamic_ncols=True,
            )
            enum_configs = enumerate(bar)
            for i, config in enum_configs:
                kwargs_with_config = kwargs.copy()
                kwargs_with_config["run_config"] = config
-               run_time = self._bench(*args, current_best_ms=best_time, **kwargs_with_config)
+               run_time = self._bench(*args, **kwargs_with_config)
                if run_time < best_time:
                    best_time = run_time
                    _best_config = config
                bar.set_description(
-                   f"Autotuning {self.name} [rank:{rank_id}] \
-                   for {_run_key}, es:{self.early_stop_cnt / len(self.configs):.2%}, \
-                   best_time: {best_time:.5f}"
+                   f"Autotuning {self.name} [rank:{rank_id}] for {_run_key}, best_time: {best_time:.5f}"
                )
+
            world_size = dist.get_world_size() if dist.is_initialized() else 1
            if world_size > 1:
                local_best = torch.tensor([best_time], device="cuda")
@@ -268,24 +252,21 @@ def _benchmark(_run_key):
            if static_key not in self.cached_configs:
                self.cached_configs[static_key] = {}
            self.cached_configs[static_key][run_key] = _best_config
-           self.sorted_cached_configs[static_key] = [(int(k), v) for k, v in self.cached_configs[static_key].items()]
-           self.sorted_cached_configs[static_key].sort(key=lambda x: x[0])
 
+           # save configs to file
            if not dist.is_initialized() or get_global_rank() == 0:
-               if os.environ.get("LIGHTLLM_TRITON_AUTOTUNE", "0") == "1":
-                   cache_file = os.path.join(self.cache_dir, f"{static_key}.json")
-                   with open(cache_file, "wb") as f:
-                       fcntl.flock(f, fcntl.LOCK_EX)
-                       try:
-                           f.write(
-                               orjson.dumps(
-                                   self.cached_configs[static_key], option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS
+               cache_file = os.path.join(self.cache_dir, f"{static_key}.json")
+               with open(cache_file, "wb") as f:
+                   fcntl.flock(f, fcntl.LOCK_EX)
+                   try:
+                       f.write(
+                           orjson.dumps(
+                               self.cached_configs[static_key], option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS
                            )
-                   finally:
-                       fcntl.flock(f, fcntl.LOCK_UN)
-                   if self.print_autotune:
-                       logger.info(f"Saved configs for {self.name} - {static_key} - {run_key}")
+                       )
+                   finally:
+                       fcntl.flock(f, fcntl.LOCK_UN)
+               logger.info(f"Saved configs for {self.name} - {static_key} - {run_key}")
 
            kwargs["run_config"] = self.cached_configs[static_key][run_key]
            full_nargs = {**self.nargs, **kwargs}
@@ -294,7 +275,8 @@ def _benchmark(_run_key):
        if static_key not in self.cached_configs:
            if not is_triton_autotune_enabled():
                logger.warning(
-                   f"No kernel config for {self.name} in {self.cache_dir}/{static_key}, using default config",
+                   f"No kernel config for {self.name} - {static_key}, \
+                       using default config. Use `LIGHTLLM_TRITON_AUTOTUNE=1` to enable autotune.",
                )
            self.cached_configs[static_key] = {}
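
Beyond removing the early-stop bookkeeping (early_stop_cnt, sorted_cached_configs, the current_best_ms argument to _bench), the save path now runs whenever rank 0 finishes tuning: the LIGHTLLM_TRITON_AUTOTUNE guard around the write and the print_autotune guard around the log line are gone. Condensed into a standalone sketch, the lock-and-serialize pattern from the new code looks roughly like this (directory, key, and dict contents are placeholders):

import fcntl
import os

import orjson


def save_configs(cache_dir: str, static_key: str, configs: dict) -> None:
    # Mirrors the write path in the diff: exclusive flock around an orjson dump,
    # released in a finally block so a failed dump cannot leave the file locked.
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = os.path.join(cache_dir, f"{static_key}.json")
    with open(cache_file, "wb") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            f.write(orjson.dumps(configs, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS))
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)


# Example with placeholder values:
# save_configs("/tmp/lightllm_autotune", "N=128,out_dtype=torch.bfloat16", {"64": {"BLOCK_M": 64}})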

lightllm/models/deepseek2/triton_kernel/rotary_emb.py

Lines changed: 2 additions & 1 deletion
@@ -120,10 +120,11 @@ def get_static_key(q, k):
 
 @autotune(
     name="rotary_emb_fwd:v1",
-    configs=get_test_configs(),
+    configs=get_test_configs,
     static_key_func=get_static_key,
     run_key_func=lambda q: str(nearest_power_of_2(q.shape[0])),
 )
+@torch.no_grad()
 def rotary_emb_fwd(q, k, cos, sin, run_config=None):
     total_len = q.shape[0]
     head_num_q, head_num_k = q.shape[1], k.shape[1]
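
The added @torch.no_grad() sits below @autotune, and decorators apply bottom-up, so the function handed to the autotuner is already wrapped in the no-grad context; any benchmarking the outer decorator performs also runs with gradients disabled. A minimal sketch of that ordering with a stand-in outer decorator (outer_decorator and check_grad_mode are hypothetical, not lightllm code):

import functools

import torch


def outer_decorator(fn):
    # Stand-in for @autotune: `fn` received here is already the no-grad-wrapped
    # function, so calling it (e.g. to benchmark it) keeps gradients off.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


@outer_decorator
@torch.no_grad()
def check_grad_mode() -> bool:
    return torch.is_grad_enabled()


assert check_grad_mode() is False  # grad mode is disabled inside the wrapped call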

lightllm/utils/envs_utils.py

Lines changed: 10 additions & 0 deletions
@@ -149,6 +149,16 @@ def get_kv_quant_calibration_inference_count():
     return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_INFERENCE_COUNT", 4000))
 
 
+def is_triton_autotune_enabled():
+    # Whether Triton autotune is enabled (read-only check)
+    return os.environ.get("LIGHTLLM_TRITON_AUTOTUNE", "0") == "1"
+
+
+def disable_triton_autotune():
+    # Disable Triton autotune (setter)
+    os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "0"
+
+
 g_model_init_done = False
 
 
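
The two helpers are verbatim moves from autotuner.py (see the deletion above), which is what lets autotuner.py import them from envs_utils. A small usage snippet, assuming a lightllm checkout with this commit applied; the enable-then-disable scenario itself is hypothetical:

import os

from lightllm.utils.envs_utils import disable_triton_autotune, is_triton_autotune_enabled

# Hypothetical scenario: enable autotune for a warmup pass, then switch it off.
os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "1"
assert is_triton_autotune_enabled()

disable_triton_autotune()
assert not is_triton_autotune_enabled()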
