Commit 2853168

fix
1 parent b9cc80d commit 2853168

File tree

23 files changed: +217 -1492 lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 1 deletion

@@ -723,6 +723,7 @@ def _check_max_len_infer(self):
         return
 
     def autotune_layers(self):
+        # Controls how many layers autotune runs over, to adapt to different models
         return self.config.get("first_k_dense_replace", 0) + 1
 
     @final
@@ -749,7 +750,7 @@ def _autotune_warmup(self):
         for input_len in warmup_lengths:
             try:
                 logger.info(f"autotune warmup for length {input_len}")
-                dummy_input_ids = torch.ones(input_len, dtype=torch.int32, device="cuda")
+                dummy_input_ids = torch.randint(0, 10000, (input_len,), dtype=torch.int32, device="cuda")
                 b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
                 mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
                 b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
@@ -777,11 +778,13 @@ def _autotune_warmup(self):
                 del model_output
                 self.req_manager.free_all()
                 self.mem_manager.free_all()
+                torch.cuda.empty_cache()
                 logger.info(f"autotune warmup for length {input_len} ok")
             except Exception as e:
                 logger.warning(f"autotune warmup for length {input_len} failed: {str(e)}")
                 self.req_manager.free_all()
                 self.mem_manager.free_all()
+                torch.cuda.empty_cache()
         self.layers_num = layer_num_bak
         torch.distributed.barrier()
         os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "0"
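
Editor's note: this hunk switches the warmup input from all-ones to random token ids and releases the CUDA caching allocator between warmup lengths. A minimal sketch of that pattern, assuming a CUDA device; warmup_once, forward_fn, and vocab_size are illustrative stand-ins, not LightLLM APIs.

import torch

def warmup_once(forward_fn, input_len: int, vocab_size: int = 10000) -> bool:
    # Run one warmup forward pass with random token ids; return True on success.
    try:
        # Random ids exercise more realistic paths than the old all-ones input.
        dummy_input_ids = torch.randint(0, vocab_size, (input_len,), dtype=torch.int32, device="cuda")
        out = forward_fn(dummy_input_ids)
        del out
        return True
    except Exception as exc:
        print(f"warmup for length {input_len} failed: {exc}")
        return False
    finally:
        # Mirror the commit's empty_cache() calls: drop cached blocks so the next
        # (possibly larger) warmup length starts from a clean allocator state.
        torch.cuda.empty_cache()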

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 22 additions & 10 deletions

@@ -30,6 +30,7 @@
     get_device_sm_shared_mem_num,
     get_device_warp_size,
 )
+from .moe_kernel_configs import MoeGroupedGemmKernelConfig
 from .moe_silu_and_mul import silu_and_mul_fwd
 from .moe_sum_reduce import moe_sum_reduce
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
@@ -117,7 +118,7 @@ def moe_align1_kernel(
     experts_topk_weight_stride0,
     experts_topk_weight_stride1,
     TOKEN_BLOCK_SIZE: tl.constexpr,
-    NUM_STAGE: tl.constexpr,
+    num_stages: tl.constexpr,
 ):
 
     expert_id = tl.program_id(axis=0)
@@ -126,7 +127,7 @@ def moe_align1_kernel(
 
     pre_sum = 0
 
-    for start_loc in tl.range(0, experts_info_n, TOKEN_BLOCK_SIZE, num_stages=NUM_STAGE):
+    for start_loc in tl.range(0, experts_info_n, TOKEN_BLOCK_SIZE, num_stages=num_stages):
         n_range = start_loc + off_n
         topk_weights_data = tl.load(topk_weights + n_range, mask=n_range < experts_info_n, other=0)
         expert_data = tl.load(
@@ -212,7 +213,7 @@ def moe_align1(
         experts_weight_info.stride(0),
         experts_weight_info.stride(1),
         TOKEN_BLOCK_SIZE=TOKEN_BLOCK_SIZE,
-        NUM_STAGE=4,
+        num_stages=4,
         num_warps=8,
         num_stages=1,
     )
@@ -478,8 +479,8 @@ def get_grouped_matmul_static_key(
             "BLOCK_SIZE_N": bn,
             "BLOCK_SIZE_K": bk,
             "GROUP_SIZE_M": gm,
-            "NUM_WARPS": nw,
-            "NUM_STAGE": ns,
+            "num_warps": nw,
+            "num_stages": ns,
         }
         for ns in [1, 2, 3, 4, 5]
         for gm in [1, 2, 4, 8]
@@ -493,8 +494,8 @@ def get_grouped_matmul_static_key(
         "BLOCK_SIZE_N": 64,
         "BLOCK_SIZE_K": 32,
         "GROUP_SIZE_M": 8,
-        "NUM_WARPS": 4,
-        "NUM_STAGE": 1,
+        "num_warps": 4,
+        "num_stages": 1,
     },
     static_key_func=get_grouped_matmul_static_key,
     run_key_func=lambda token_num_mul_topk_num: str(nearest_power_of_2(token_num_mul_topk_num)),
@@ -536,7 +537,6 @@ def grouped_matmul(
     assert expert_to_token_num.is_contiguous()
     assert expert_to_weights.is_contiguous()
     assert expert_weights.is_contiguous()
-    assert run_config is not None
 
     # for deepseek_v3 block-wise quant
     block_size_n = 0
@@ -546,12 +546,24 @@ def grouped_matmul(
         block_size_n = expert_weights.shape[1] // expert_to_weights_scale.shape[1]
         block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2]
 
+    if run_config is None:
+        run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config(
+            M=token_inputs.shape[0],
+            N=n,
+            K=k,
+            topk_num=topk_num,
+            expert_num=expert_num,
+            mul_routed_weight=mul_routed_weight,
+            use_fp8_w8a8=use_fp8_w8a8,
+            out_dtype=str(out.dtype),
+        )
+
     BLOCK_SIZE_M = run_config["BLOCK_SIZE_M"]
     BLOCK_SIZE_N = run_config["BLOCK_SIZE_N"]
     BLOCK_SIZE_K = run_config["BLOCK_SIZE_K"]
     GROUP_SIZE_M = run_config["GROUP_SIZE_M"]
-    num_warps = run_config["NUM_WARPS"]
-    num_stages = run_config["NUM_STAGE"]
+    num_warps = run_config["num_warps"]
+    num_stages = run_config["num_stages"]
 
     if block_size_k != 0:
         # If block-wise quantization is used, the tile size must not exceed the block size
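
Editor's note: the hunks above rename NUM_WARPS/NUM_STAGE to Triton's lowercase num_warps/num_stages launch keys and make run_config optional, falling back to MoeGroupedGemmKernelConfig.try_to_get_best_config. A minimal sketch of that fallback pattern; the default table and resolve_run_config are hypothetical stand-ins for the real tuned-config lookup.

from typing import Dict, Optional

# Hypothetical default; the real lookup is keyed on M, N, K, topk_num, expert_num, etc.
_DEFAULT_GEMM_CONFIG: Dict[str, int] = {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 32,
    "GROUP_SIZE_M": 8,
    "num_warps": 4,   # lowercase keys now match Triton's launch meta-parameters
    "num_stages": 1,
}

def resolve_run_config(run_config: Optional[Dict[str, int]],
                       tuned: Dict[str, Dict[str, int]],
                       run_key: str) -> Dict[str, int]:
    # Prefer an explicit caller config, then a tuned entry for this shape key,
    # then a safe default, so the kernel no longer asserts run_config is not None.
    if run_config is not None:
        return run_config
    return tuned.get(run_key, _DEFAULT_GEMM_CONFIG)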

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 7 additions & 1 deletion

@@ -79,18 +79,24 @@ def _silu_and_mul_kernel_fast(
 def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, run_config=None):
     assert input.is_contiguous()
     assert output.is_contiguous()
-    assert run_config is not None
+
     stride_input_m = input.stride(0)
     stride_input_n = input.stride(1)
     stride_output_m = output.stride(0)
     stride_output_n = output.stride(1)
     size_m = input.shape[0]
     size_n = input.shape[-1] // 2
 
+    if not run_config:
+        run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype))
+
     BLOCK_M = run_config["BLOCK_M"]
     BLOCK_N = run_config["BLOCK_N"]
     num_warps = run_config["num_warps"]
     NUM_STAGES = run_config["NUM_STAGES"]
+    # limit the grid size to avoid the invalid argument error of triton
+    while triton.cdiv(size_m, BLOCK_M) > 8192:
+        BLOCK_M *= 2
 
     grid = (
         triton.cdiv(size_n, BLOCK_N),
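
Editor's note: besides the same run_config fallback, this hunk caps the launch grid by doubling BLOCK_M until the number of row blocks is at most 8192. A minimal sketch of that clamp; 8192 is simply the threshold the commit uses, not a limit derived here.

import math

def clamp_block_m(size_m: int, block_m: int, max_row_blocks: int = 8192) -> int:
    # Grow BLOCK_M until ceil(size_m / BLOCK_M) fits under max_row_blocks,
    # mirroring the while-loop added above to avoid over-large Triton grids.
    while math.ceil(size_m / block_m) > max_row_blocks:
        block_m *= 2
    return block_m

# Example: ten million rows with BLOCK_M=64 would need 156250 row blocks,
# so the clamp doubles BLOCK_M up to 2048 (4883 row blocks).
assert clamp_block_m(10_000_000, 64) == 2048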

lightllm/common/fused_moe/moe_sum_reduce.py

Lines changed: 5 additions & 0 deletions

@@ -72,6 +72,11 @@ def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict =
     token_num, topk_num, hidden_dim = input.shape
     assert output.shape[0] == token_num and output.shape[1] == hidden_dim
 
+    if not run_config:
+        run_config = MoeSumReduceKernelConfig.try_to_get_best_config(
+            M=token_num, topk_num=topk_num, hidden_dim=hidden_dim, out_dtype=str(output.dtype)
+        )
+
     BLOCK_M = run_config["BLOCK_M"]
     BLOCK_DIM = run_config["BLOCK_DIM"]
     NUM_STAGE = run_config["NUM_STAGE"]
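
Editor's note: with the same fallback in moe_sum_reduce, callers can now omit run_config. A hypothetical usage sketch, assuming a CUDA device and that the import path below is unchanged; the shapes follow the asserts in the hunk.

import torch
from lightllm.common.fused_moe.moe_sum_reduce import moe_sum_reduce

token_num, topk_num, hidden_dim = 8, 4, 1024
x = torch.randn(token_num, topk_num, hidden_dim, dtype=torch.bfloat16, device="cuda")
out = torch.empty(token_num, hidden_dim, dtype=torch.bfloat16, device="cuda")
# No run_config: the kernel resolves one via MoeSumReduceKernelConfig.try_to_get_best_config.
moe_sum_reduce(x, out)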

lightllm/common/triton_utils/all_kernel_configs/triton_3.3.1/NVIDIA H200/grouped_matmul:v1/K=192,N=4096,expert_num=128,mul_routed_weight=True,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=False.json

Lines changed: 0 additions & 122 deletions
This file was deleted.
