
Commit 0a79f6e

fix
1 parent a6b2f08 commit 0a79f6e

5 files changed: 17 additions and 8 deletions

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -450,7 +450,7 @@ def grouped_matmul_kernel(
     return
 
 
-def get_grouped_matmul_static_key(
+def _get_grouped_matmul_static_key(
     expert_weights: torch.Tensor,
     topk_num: int,
     out: torch.Tensor,
@@ -489,7 +489,7 @@ def get_grouped_matmul_static_key(
         for bn in [16, 32, 64, 128]
         for bk in [16, 32, 64, 128]
     ],
-    static_key_func=get_grouped_matmul_static_key,
+    static_key_func=_get_grouped_matmul_static_key,
     run_key_func=lambda token_num_mul_topk_num: str(nearest_power_of_2(token_num_mul_topk_num)),
 )
 def grouped_matmul(
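
The decorator arguments shown above split the tuning cache key in two: static_key_func captures call-invariant properties (weight shapes, topk, dtypes) that select a tuned-config table, while run_key_func buckets the per-call size through nearest_power_of_2 so nearby sizes reuse the same config. The same helper rename (leading underscore to mark the function as module-private) and decorator rewiring repeats in the next two files. A minimal self-contained sketch of that split, using placeholder names and a local power-of-two helper rather than lightllm's actual autotuner API:

import torch


def _next_power_of_2(n: int) -> int:
    # Local stand-in for lightllm's nearest_power_of_2 bucketing helper.
    return 1 << max(n - 1, 0).bit_length()


def _example_static_key(weights: torch.Tensor, out: torch.Tensor) -> str:
    # Call-invariant part of the cache key: depends on model shapes/dtypes only.
    return f"expert_num={weights.shape[0]},hidden={weights.shape[-1]},dtype={out.dtype}"


def _example_run_key(token_num_mul_topk_num: int) -> str:
    # Per-call part of the cache key: the varying size, bucketed to a power of two.
    return str(_next_power_of_2(token_num_mul_topk_num))


weights = torch.empty(8, 256, 1024)                  # (expert_num, inter_dim, hidden_dim), illustrative
out = torch.empty(300, 1024, dtype=torch.bfloat16)
print(_example_static_key(weights, out))             # expert_num=8,hidden=1024,dtype=torch.bfloat16
print(_example_run_key(300 * 8))                     # "4096": 2400 rounds up to the next power of two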

lightllm/common/fused_moe/moe_sum_reduce.py

Lines changed: 2 additions & 2 deletions
@@ -48,7 +48,7 @@ def _moe_sum_reduce_kernel(
     tl.store(store_t_ptr, accumulator.to(input_ptr.dtype.element_ty), mask=offs_dim < dim_end)
 
 
-def get_static_key(input, output):
+def _get_static_key(input, output):
     return f"topk_num={input.shape[1]},hidden_dim={input.shape[2]},out_dtype={output.dtype}"
 
 
@@ -61,7 +61,7 @@ def get_static_key(input, output):
         for bm in [1, 2, 4, 8, 16, 32]
         for bd in [64, 128, 256, 512, 1024]
     ],
-    static_key_func=get_static_key,
+    static_key_func=_get_static_key,
     run_key_func=lambda input: str(nearest_power_of_2(input.shape[0])),
 )
 def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = None):
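
Here the static key is fully visible in the diff: it depends only on the topk and hidden dimensions of the 3-D input plus the output dtype, so the token count can vary without invalidating cached configs; the token count is handled by the power-of-two run key instead. For example, with illustrative shapes, evaluating the same f-string as _get_static_key above:

import torch

# (token_num, topk_num, hidden_dim), presumably reduced over the topk dimension;
# the parameter names mirror the diff, so `input` shadows the builtin here.
input = torch.empty(300, 8, 7168, dtype=torch.bfloat16)
output = torch.empty(300, 7168, dtype=torch.bfloat16)

print(f"topk_num={input.shape[1]},hidden_dim={input.shape[2]},out_dtype={output.dtype}")
# -> topk_num=8,hidden_dim=7168,out_dtype=torch.bfloat16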

lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -165,7 +165,7 @@ def get_test_configs():
     return fp8_gemm_configs
 
 
-def get_static_key(A, B, block_size, dtype):
+def _get_static_key(A, B, block_size, dtype):
     M, K = A.shape
     _, N = B.shape
     return {
@@ -179,7 +179,7 @@ def get_static_key(A, B, block_size, dtype):
 @autotune(
     name="w8a8_block_fp8_matmul:v1",
     configs=get_test_configs(),
-    static_key_func=get_static_key,
+    static_key_func=_get_static_key,
     run_key_func=lambda M: str(nearest_power_of_2(M)),
 )
 def w8a8_block_fp8_matmul(

lightllm/common/triton_utils/autotuner.py

Lines changed: 0 additions & 2 deletions
@@ -89,8 +89,6 @@ def __init__(
         warmup=None,
         rep=None,
     ):
-        # Whether to print autotune logs
-        self.print_autotune = os.environ.get("LIGHTLLM_TRITON_PRINT_AUTOTUNE", "0") == "1"
         # Whether to use this autotune decorator
         self.disable_autotune = os.environ.get("DISABLE_AUTOTUNE_DECORATOR", "0") == "1"

lightllm/models/deepseek2/triton_kernel/rotary_emb.py

Lines changed: 11 additions & 0 deletions
@@ -132,6 +132,17 @@ def rotary_emb_fwd(q, k, cos, sin, run_config=None):
     assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}"
     assert triton.next_power_of_2(head_dim) == head_dim
 
+    from .rotary_emb_config import DeepseekV3RotaryKernelConfig
+
+    if not run_config:
+        run_config = DeepseekV3RotaryKernelConfig.try_to_get_best_config(
+            M=total_len,
+            Q_HEAD_NUM=head_num_q,
+            K_HEAD_NUM=head_num_k,
+            HEAD_DIM=head_dim,
+            dtype=str(q.dtype),
+        )
+
     BLOCK_SEQ = run_config["BLOCK_SEQ"]
     HEAD_PARALLEL_NUM = run_config["HEAD_PARALLEL_NUM"]
     num_warps = run_config["num_warps"]
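
The added lines give rotary_emb_fwd the same fallback behaviour as the other tuned kernels: an explicitly passed run_config wins, otherwise the kernel asks DeepseekV3RotaryKernelConfig for the best config stored for the current sizes and dtype. A minimal self-contained sketch of that lookup-with-fallback pattern (ExampleKernelConfig, its storage, and the default values are placeholders, not the actual lightllm implementation; only the config keys BLOCK_SEQ, HEAD_PARALLEL_NUM, and num_warps are taken from the diff):

from typing import Dict, Optional


class ExampleKernelConfig:
    # Placeholder for a per-kernel store of previously tuned configs.
    _store: Dict[str, Dict] = {}

    @classmethod
    def try_to_get_best_config(cls, **static_key) -> Dict:
        key = ",".join(f"{k}={v}" for k, v in sorted(static_key.items()))
        # Fall back to a conservative default when nothing was tuned for this key.
        return cls._store.get(key, {"BLOCK_SEQ": 16, "HEAD_PARALLEL_NUM": 1, "num_warps": 4})


def example_rotary_launch(total_len: int, head_dim: int, run_config: Optional[Dict] = None) -> Dict:
    # Same control flow as the added lines: a caller-supplied config takes priority.
    if not run_config:
        run_config = ExampleKernelConfig.try_to_get_best_config(
            M=total_len, HEAD_DIM=head_dim, dtype="torch.bfloat16"
        )
    return run_config


print(example_rotary_launch(total_len=256, head_dim=64))
# {'BLOCK_SEQ': 16, 'HEAD_PARALLEL_NUM': 1, 'num_warps': 4}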
