Skip to content

Commit 57e0cca

Browse files
committed
fix
1 parent 8a1a59e commit 57e0cca

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def _get_grouped_matmul_configs():
492492
name="grouped_matmul:v1",
493493
configs_gen_func=_get_grouped_matmul_configs,
494494
static_key_func=_get_grouped_matmul_static_key,
495-
run_key_func=lambda token_num_mul_topk_num: token_num_mul_topk_num,
495+
run_key_func=lambda token_inputs: token_inputs.shape[0],
496496
)
497497
def grouped_matmul(
498498
token_num_mul_topk_num: int,

lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import lru_cache
88
from typing import Any, Dict, List, Optional, Tuple
99
from triton import Config
10-
from lightllm.common.triton_utils.autotuner import autotune, nearest_power_of_2
10+
from lightllm.common.triton_utils.autotuner import autotune
1111

1212

1313
class Fp8BlockMMKernelConfig(KernelConfigs):
@@ -180,7 +180,7 @@ def _get_static_key(A, B, block_size, dtype):
180180
name="w8a8_block_fp8_matmul:v1",
181181
configs_gen_func=get_test_configs,
182182
static_key_func=_get_static_key,
183-
run_key_func=lambda M: M,
183+
run_key_func=lambda A: A.shape[0],
184184
)
185185
def w8a8_block_fp8_matmul(
186186
A: torch.Tensor,

lightllm/common/triton_utils/autotuner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,15 @@ def _autotune(self, args, kwargs, static_key, run_key):
225225

226226
# save configs to file
227227
if not dist.is_initialized() or get_global_rank() == 0:
228-
cache_file = os.path.join(self.cache_dir, f"{KernelConfigs.get_config_file_name(static_key)}.json")
228+
cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
229229
with open(cache_file, "wb") as f:
230230
fcntl.flock(f, fcntl.LOCK_EX)
231231
try:
232232
f.write(
233-
orjson.dumps(self.cached_configs[static_key], option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS)
233+
orjson.dumps(
234+
self.cached_configs[static_key],
235+
option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
236+
)
234237
)
235238
finally:
236239
fcntl.flock(f, fcntl.LOCK_UN)

0 commit comments

Comments
 (0)