
Commit 7f92a33

Merge branch 'fused_moe_improve' of https://github.com/ModelTC/lightllm into fused_moe_improve
2 parents 94ca166 + 5a39dec commit 7f92a33

15 files changed: +511 -313 lines
@@ -1 +1 @@
-{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 2, "num_stages": 2}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}}
+{"1": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "8192": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "16384": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "32768": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}}
@@ -1 +1 @@
-{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 2, "num_stages": 2}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}}
+{"1": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4}, "64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "128": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 2}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 8, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}}

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 14 additions & 3 deletions
@@ -67,11 +67,22 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             scoring_func=self.scoring_func,
-            num_fused_shared_experts=self.num_fused_shared_experts,
         )
         if self.num_fused_shared_experts > 0:
-            topk_ids[:, -1] = self.n_routed_experts - 1
-            topk_weights[:, -1] = 1.0 / self.routed_scaling_factor
+            pad_topk_ids = torch.arange(
+                start=self.n_routed_experts - self.num_fused_shared_experts,
+                end=self.n_routed_experts,
+                step=1,
+                dtype=topk_ids.dtype,
+                device="cuda").view(1, self.num_fused_shared_experts).repeat(topk_ids.shape[0], 1)
+            pad_topk_weights = torch.full((topk_weights.shape[0], self.num_fused_shared_experts),
+                fill_value=1.0 / self.routed_scaling_factor,
+                device="cuda",
+                dtype=topk_weights.dtype)
+
+            topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
+            topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1)
+
         w1, w1_scale = self.w1
         w2, w2_scale = self.w2
         use_fp8_w8a8 = self.quant_method is not None
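Note on the change above: shared-expert handling moves out of the routing kernel. Previously triton_grouped_topk reserved extra output columns (see grouped_topk.py below) and this method filled only the last one in place; now routing returns exactly top_k columns and this method appends one column per fused shared expert, using the tail expert ids [n_routed_experts - num_fused_shared_experts, n_routed_experts) with a fixed weight of 1.0 / routed_scaling_factor. A standalone sketch of the same padding with toy sizes (names mirror the diff; nothing here is lightllm API):

import torch

n_routed_experts = 8           # shared experts are packed at the tail of the expert table
num_fused_shared_experts = 2
routed_scaling_factor = 2.5

topk_ids = torch.tensor([[0, 3], [5, 1]])              # (tokens, top_k) from routing
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4]])

pad_topk_ids = torch.arange(
    n_routed_experts - num_fused_shared_experts, n_routed_experts, dtype=topk_ids.dtype
).view(1, -1).repeat(topk_ids.shape[0], 1)
pad_topk_weights = torch.full(
    (topk_weights.shape[0], num_fused_shared_experts), 1.0 / routed_scaling_factor
)

topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)              # [[0, 3, 6, 7], [5, 1, 6, 7]]
topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1)  # shared columns are all 0.4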

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 19 additions & 12 deletions
@@ -462,6 +462,7 @@ def grouped_matmul(
     out: torch.Tensor,
     mul_routed_weight: bool,
     use_fp8_w8a8: bool,
+    alloc_tensor_func=torch.empty,
     reused_mblock_infos=None,
     run_config: Optional[dict] = None,
 ):
@@ -525,8 +526,8 @@ def grouped_matmul(
         else:
             _m, _k = token_inputs.shape
             assert _k % block_size_k == 0
-            input_scale = torch.empty((_m, _k // block_size_k), dtype=torch.float32, device=token_inputs.device)
-            qinput_tensor = torch.empty((_m, _k), dtype=expert_weights.dtype, device=token_inputs.device)
+            input_scale = alloc_tensor_func((_m, _k // block_size_k), dtype=torch.float32, device=token_inputs.device)
+            qinput_tensor = alloc_tensor_func((_m, _k), dtype=expert_weights.dtype, device=token_inputs.device)
             per_token_group_quant_fp8(token_inputs, block_size_k, qinput_tensor, input_scale)
             token_inputs, token_input_scale = qinput_tensor, input_scale

@@ -611,6 +612,7 @@ def fused_experts_impl(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    alloc_tensor_func=torch.empty,
     run_config: Optional[dict] = None,
 ):
     # Check constraints.
@@ -625,26 +627,29 @@
     CHUNK_SIZE = FFN_MOE_CHUNK_SIZE
     topk_num = topk_ids.shape[1]
     M = min(num_tokens, CHUNK_SIZE)
-
-    cache = torch.empty(M * topk_num * max(N, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype)
-    intermediate_cache1 = cache[: M * topk_num * N].view(M, topk_num, N)
-    intermediate_cache2 = torch.empty((M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
-    intermediate_cache3 = cache[: M * topk_num * w2.shape[1]].view(M, topk_num, w2.shape[1])
+
+    intermediate_cache13_shared = alloc_tensor_func((M, topk_num, max(N, w2.shape[1])), device=hidden_states.device, dtype=hidden_states.dtype)
+    intermediate_cache1 = intermediate_cache13_shared.view(-1)[:(M * topk_num * N)].view(M, topk_num, N)
+    intermediate_cache2 = alloc_tensor_func(
+        (M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
+    )
+    intermediate_cache3 = intermediate_cache13_shared.view(-1)[:(M * topk_num * w2.shape[1])].view(M, topk_num, w2.shape[1])

     if inplace:
         out_hidden_states = hidden_states
     else:
-        out_hidden_states = torch.empty(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
+        out_hidden_states = alloc_tensor_func(
+            hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype
+        )

     for chunk in range(triton.cdiv(num_tokens, CHUNK_SIZE)):
         begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens))
         curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
         tokens_in_chunk, _ = curr_hidden_states.shape

-        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
-            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
-            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
-            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+        intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+        intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+        intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]

         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
@@ -668,6 +673,7 @@ def fused_experts_impl(
             out=intermediate_cache1.view(-1, N),
             mul_routed_weight=False,
             use_fp8_w8a8=use_fp8_w8a8,
+            alloc_tensor_func=alloc_tensor_func,
             run_config=run_config,
         )

@@ -686,6 +692,7 @@ def fused_experts_impl(
             out=intermediate_cache3.view(-1, w2.shape[1]),
             mul_routed_weight=True,
             use_fp8_w8a8=use_fp8_w8a8,
+            alloc_tensor_func=alloc_tensor_func,
             reused_mblock_infos=reused_mblock_infos,
             run_config=run_config,
         )
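Two things change in this file. First, grouped_matmul and fused_experts_impl accept an alloc_tensor_func hook (default torch.empty), so a caller can route the temporary buffers (quantization scratch, intermediate caches, the non-inplace output) through its own allocator, for example a pool that reuses scratch memory across calls. Second, the shared cache behind intermediate_cache1 and intermediate_cache3 becomes a shaped tensor (intermediate_cache13_shared) allocated through the same hook (as before, the two can alias because cache1 is consumed into cache2 before cache3 is written), and the per-chunk slicing of the caches is now applied unconditionally instead of only for a trailing short chunk. A rough sketch of a caller-supplied allocator is below; the class is illustrative only (a real pool would have to track buffer lifetimes), and alloc_tensor_func just needs to accept (shape, dtype=..., device=...) like torch.empty:

import torch

class NaiveTensorPool:
    # Illustrative allocator: reuse one cached buffer per (shape, dtype, device) key.
    def __init__(self):
        self._pool = {}

    def __call__(self, shape, dtype=None, device=None):
        key = (tuple(shape), dtype, str(device))
        if key not in self._pool:
            self._pool[key] = torch.empty(shape, dtype=dtype, device=device)
        return self._pool[key]

# pool = NaiveTensorPool()
# fused_experts_impl(..., alloc_tensor_func=pool)  # scratch buffers now come from the pool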

lightllm/common/fused_moe/grouped_topk.py

Lines changed: 2 additions & 3 deletions
@@ -208,7 +208,6 @@ def triton_grouped_topk(
     topk_group: int = 0,
     scoring_func: str = "softmax",
     group_score_used_topk_num=2,
-    num_fused_shared_experts: int = 0,
 ):

     if correction_bias is not None:
@@ -223,8 +222,8 @@
     dtype = torch.float32

     scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda")
-    out_topk_weights = torch.empty((token_num, topk + num_fused_shared_experts), dtype=torch.float32, device="cuda")
-    out_topk_ids = torch.empty((token_num, topk + num_fused_shared_experts), dtype=torch.long, device="cuda")
+    out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda")
+    out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda")

     assert total_expert_num % num_expert_group == 0
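With the shared-expert columns handled by the caller, triton_grouped_topk now allocates output buffers of exactly (token_num, topk). For intuition, grouped top-k routing works roughly as follows: the total_expert_num experts are split into num_expert_group contiguous groups, each group is scored from its strongest experts, only the topk_group best groups remain eligible, and the final top-k is taken over the surviving experts. A rough PyTorch rendering is below; softmax scoring without correction_bias is assumed, and reading group_score_used_topk_num as "how many experts contribute to a group's score" is an interpretation, so treat this as intuition rather than a drop-in equivalent of the Triton kernel:

import torch

def grouped_topk_reference(logits, num_expert_group, topk_group, topk, group_score_used_topk_num=2):
    token_num, total_expert_num = logits.shape
    assert total_expert_num % num_expert_group == 0
    group_size = total_expert_num // num_expert_group
    scores = torch.softmax(logits.float(), dim=-1)

    # Score each group by the sum of its best experts, keep only the topk_group best groups.
    group_scores = (
        scores.view(token_num, num_expert_group, group_size)
        .topk(group_score_used_topk_num, dim=-1).values.sum(dim=-1)
    )
    keep = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, keep, 1.0)
    expert_mask = group_mask.repeat_interleave(group_size, dim=-1)

    # Ordinary top-k over the experts of the surviving groups.
    masked = scores.masked_fill(expert_mask == 0, float("-inf"))
    topk_weights, topk_ids = torch.topk(masked, k=topk, dim=-1)
    return topk_weights, topk_ids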

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 46 additions & 106 deletions
@@ -4,9 +4,8 @@
 import triton.language as tl
 from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig

-
 @triton.jit
-def _silu_and_mul_kernel(
+def _silu_and_mul_kernel_fast(
     input_ptr,
     output_ptr,
     stride_input_m,
@@ -17,89 +16,48 @@ def _silu_and_mul_kernel(
     size_n,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
-):
-    stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
-    stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)
-
-    tid = tl.program_id(0)
-    input_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)
-    output_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)
-
-    pid = tl.program_id(1)
-    input_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)
-    output_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    up_offsets = input_m_offsets[:, None] * stride_input_m + (input_n_offsets[None, :] + size_n)
-    gate_offsets = input_m_offsets[:, None] * stride_input_m + input_n_offsets[None, :]
-    res_offsets = output_m_offsets[:, None] * stride_output_m + output_n_offsets[None, :]
-
-    up = tl.load(
-        input_ptr + up_offsets,
-        mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None],
-        other=0.0,
-    )
-    gate = tl.load(
-        input_ptr + gate_offsets,
-        mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None],
-        other=0.0,
-    ).to(tl.float32)
-
-    gate = gate / (1 + tl.exp(-gate))
-    gate = gate.to(input_ptr.dtype.element_ty)
-
-    tl.store(
-        output_ptr + res_offsets,
-        up * gate,
-        mask=(output_n_offsets < size_n)[None, :] * (output_m_offsets < size_m)[:, None],
-    )
-
-
-@triton.jit
-def _silu_and_mul_kernel_fast(
-    input_ptr,
-    output_ptr,
-    stride_input_m,
-    stride_input_n,
-    stride_output_m,
-    stride_output_n,
-    size_n,
-    BLOCK_N: tl.constexpr,
     NEED_MASK: tl.constexpr,
 ):
     stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
     stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)

-    cur_batch = tl.program_id(0)
-    pid = tl.program_id(1)
-    n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    up_offsets = cur_batch * stride_input_m + (n_offsets[None, :] + size_n)
-    gate_offsets = cur_batch * stride_input_m + n_offsets[None, :]
-    res_offsets = cur_batch * stride_output_m + n_offsets[None, :]
+    m_block_index = tl.program_id(0)
+    n_block_index = tl.program_id(1)
+    n_offsets = n_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
+    m_start_index = m_block_index * BLOCK_M
+    m_end_index = (m_block_index + 1) * BLOCK_M
+    m_end_index = tl.where(m_end_index < size_m, m_end_index, size_m)
     if NEED_MASK:
         mask = n_offsets[None, :] < size_n
+        other = 0.0
     else:
-        mask = True
-
-    up = tl.load(
-        input_ptr + up_offsets,
-        mask=mask,
-        other=0.0,
-    )
-    gate = tl.load(
-        input_ptr + gate_offsets,
-        mask=mask,
-        other=0.0,
-    ).to(tl.float32)
-
-    gate = gate / (1 + tl.exp(-gate))
-    gate = gate.to(input_ptr.dtype.element_ty)
-
-    tl.store(
-        output_ptr + res_offsets,
-        up * gate,
-        mask=mask,
-    )
+        mask = None
+        other = None
+
+    for m_index in range(m_start_index, m_end_index):
+        gate_offsets = m_index * stride_input_m + n_offsets[None, :]
+        up_offsets = m_index * stride_input_m + (n_offsets[None, :] + size_n)
+        out_offsets = m_index * stride_output_m + n_offsets[None, :]
+
+        up = tl.load(
+            input_ptr + up_offsets,
+            mask=mask,
+            other=other,
+        )
+        gate = tl.load(
+            input_ptr + gate_offsets,
+            mask=mask,
+            other=other,
+        ).to(tl.float32)
+
+        gate = gate / (1 + tl.exp(-gate))
+        gate = gate.to(input_ptr.dtype.element_ty)
+
+        tl.store(
+            output_ptr + out_offsets,
+            up * gate,
+            mask=mask,
+        )


 def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
@@ -116,26 +74,6 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
     if not run_config:
         run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype))

-    if size_m <= 4096:
-        BLOCK_N = run_config["BLOCK_N"]
-        grid = (
-            size_m,
-            triton.cdiv(size_n, BLOCK_N),
-        )
-        NEED_MASK = size_n % BLOCK_N != 0
-        _silu_and_mul_kernel_fast[grid](
-            input,
-            output,
-            stride_input_m,
-            stride_input_n,
-            stride_output_m,
-            stride_output_n,
-            size_n,
-            BLOCK_N=BLOCK_N,
-            NEED_MASK=NEED_MASK,
-        )
-        return
-
     BLOCK_M = run_config["BLOCK_M"]
     BLOCK_N = run_config["BLOCK_N"]
     num_warps = run_config["num_warps"]
@@ -144,17 +82,19 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
         triton.cdiv(size_m, BLOCK_M),
         triton.cdiv(size_n, BLOCK_N),
     )
-    _silu_and_mul_kernel[grid](
-        input,
-        output,
-        stride_input_m,
-        stride_input_n,
-        stride_output_m,
-        stride_output_n,
-        size_m,
-        size_n,
+    NEED_MASK = (size_n % BLOCK_N) != 0
+    _silu_and_mul_kernel_fast[grid](
+        input_ptr=input,
+        output_ptr=output,
+        stride_input_m=stride_input_m,
+        stride_input_n=stride_input_n,
+        stride_output_m=stride_output_m,
+        stride_output_n=stride_output_n,
+        size_m=size_m,
+        size_n=size_n,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
+        NEED_MASK=NEED_MASK,
         num_warps=num_warps,
     )
     return
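The two kernels collapse into a single _silu_and_mul_kernel_fast: each program now owns a BLOCK_M range of rows and loops over them, masking along N only when BLOCK_N does not divide size_n (NEED_MASK), so the old per-row fast path and the size_m <= 4096 dispatch in silu_and_mul_fwd are gone. Functionally the kernel still computes the SwiGLU-style activation: each input row holds the gate half in columns [0, size_n) and the up half in columns [size_n, 2*size_n), and the output row is up * silu(gate). A tiny PyTorch oracle for checking the Triton kernel, with the layout read off the diff (not lightllm test code):

import torch

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (m, 2 * n), gate half first, up half second.
    n = x.shape[-1] // 2
    gate, up = x[..., :n], x[..., n:]
    return up * torch.nn.functional.silu(gate.float()).to(x.dtype)

# x = torch.randn(64, 2 * 1408, device="cuda", dtype=torch.float16)
# out = torch.empty(64, 1408, device="cuda", dtype=torch.float16)
# silu_and_mul_fwd(x, out)
# torch.testing.assert_close(out, silu_and_mul_reference(x), rtol=1e-2, atol=1e-2)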

lightllm/common/fused_moe/topk_select.py

Lines changed: 0 additions & 2 deletions
@@ -175,7 +175,6 @@ def select_experts(
     num_expert_group: Optional[int] = None,
     scoring_func: str = "softmax",
     custom_routing_function: Optional[Callable] = None,
-    num_fused_shared_experts: int = 0,
 ):
     from lightllm.common.fused_moe.topk_select import fused_topk
     from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk
@@ -211,7 +210,6 @@
             topk_group=topk_group,
             scoring_func=scoring_func,
             group_score_used_topk_num=group_score_topk_num,
-            num_fused_shared_experts=num_fused_shared_experts,
         )

     elif custom_routing_function is None:
