
Commit aeb80a5

add fused expert
1 parent 038e089 commit aeb80a5

15 files changed: +82 -40 lines
@@ -0,0 +1 @@
{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}}
@@ -0,0 +1 @@
{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}}
@@ -0,0 +1 @@
{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}}

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 4 additions & 4 deletions
@@ -84,17 +84,17 @@ def __init__(
         self.e_score_correction_bias = None
         self.w2_list = [None] * ep_load_expert_num
         self.w2_scale_list = [None] * ep_load_expert_num
-        self.scoring_func = network_config["scoring_func"]
+        self.scoring_func = "softmax" # network_config["scoring_func"]
         self.w1 = [None, None] # weight, weight_scale
         self.w2 = [None, None] # weight, weight_scale
         self.use_fp8_w8a8 = self.quant_method is not None
-
+        network_config["n_group"] = 0
         self.num_experts_per_tok = network_config["num_experts_per_tok"]
         self.use_grouped_topk = network_config["n_group"] > 0
         self.norm_topk_prob = network_config["norm_topk_prob"]
         self.n_group = network_config["n_group"]
-        self.topk_group = network_config["topk_group"]
-        self.routed_scaling_factor = network_config["routed_scaling_factor"]
+        self.topk_group = 0 # network_config["topk_group"]
+        self.routed_scaling_factor = 0 # network_config["routed_scaling_factor"]

         self.lock = threading.Lock()
         # init buffer
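
The hard-coded network_config["n_group"] = 0 makes use_grouped_topk evaluate to False, so this EP weight falls back to plain softmax top-k routing with topk_group and routed_scaling_factor zeroed out. A toy sketch of what plain softmax top-k selection computes; this is an illustration of the routing math only, not the repository's Triton kernel:

import torch

def softmax_topk(router_logits: torch.Tensor, top_k: int, renormalize: bool = True):
    # Score all experts with a softmax over the router logits.
    scores = torch.softmax(router_logits.float(), dim=-1)
    # Keep the top_k experts per token.
    topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)
    if renormalize:
        # Renormalize the kept weights so they sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids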

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 9 additions & 1 deletion
@@ -16,6 +16,7 @@ def __init__(
         e_score_correction_bias_name: str,
         weight_prefix: str,
         n_routed_experts: int,
+        num_fused_shared_experts: int,
         split_inter_size: int,
         data_type: torch.dtype,
         network_config: Dict[str, Any],
@@ -34,7 +35,10 @@ def __init__(

         self.e_score_correction_bias_name = e_score_correction_bias_name
         self.weight_prefix = weight_prefix
-        self.n_routed_experts = n_routed_experts
+        assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now."
+        self.n_routed_experts = n_routed_experts + num_fused_shared_experts
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0)
         self.split_inter_size = split_inter_size
         self.data_type_ = data_type
         self.tp_rank_ = get_current_rank_in_dp()
@@ -63,7 +67,11 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             scoring_func=self.scoring_func,
+            num_fused_shared_experts=self.num_fused_shared_experts,
         )
+        if self.num_fused_shared_experts > 0:
+            topk_ids[:, -1] = self.n_routed_experts - 1
+            topk_weights[:, -1] = 1.0 / self.routed_scaling_factor
         w1, w1_scale = self.w1
         w2, w2_scale = self.w2
         use_fp8_w8a8 = self.quant_method is not None
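
This hunk is the core of the fused shared expert: the shared expert is treated as one extra routed expert (id n_routed_experts - 1), select_experts reserves one extra top-k slot per token, and that slot is then overwritten so every token always routes to the shared expert with weight 1 / routed_scaling_factor. A toy illustration of the slot bookkeeping; the shapes and constants below are made up for the example:

import torch

num_routed_experts = 8            # real routed experts: ids 0..7 (illustrative)
num_fused_shared_experts = 1      # shared expert stored as the extra id 8
n_routed_experts = num_routed_experts + num_fused_shared_experts
routed_scaling_factor = 2.5       # illustrative value
top_k = 2                         # router keeps 2 routed experts per token
token_num = 4

# select_experts() is asked for top_k + num_fused_shared_experts columns,
# so the last column of each row is a reserved slot.
topk_ids = torch.randint(0, num_routed_experts, (token_num, top_k + num_fused_shared_experts))
topk_weights = torch.rand(token_num, top_k + num_fused_shared_experts)

# Force the reserved slot to the shared expert, weighted by 1 / routed_scaling_factor.
topk_ids[:, -1] = n_routed_experts - 1
topk_weights[:, -1] = 1.0 / routed_scaling_factor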

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 0 additions & 1 deletion
@@ -648,7 +648,6 @@ def fused_experts_impl(
    CHUNK_SIZE = FFN_MOE_CHUNK_SIZE
    topk_num = topk_ids.shape[1]
    M = min(num_tokens, CHUNK_SIZE)
-
    intermediate_cache1 = alloc_tensor_func((M, topk_num, N), device=hidden_states.device, dtype=hidden_states.dtype)
    intermediate_cache2 = alloc_tensor_func(
        (M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype

lightllm/common/fused_moe/grouped_topk.py

Lines changed: 3 additions & 2 deletions
@@ -208,6 +208,7 @@ def triton_grouped_topk(
    topk_group: int = 0,
    scoring_func: str = "softmax",
    group_score_used_topk_num=2,
+    num_fused_shared_experts: int = 0,
):

    if correction_bias is not None:
@@ -222,8 +223,8 @@ def triton_grouped_topk(
        dtype = torch.float32

    scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda")
-    out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda")
-    out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda")
+    out_topk_weights = torch.empty((token_num, topk + num_fused_shared_experts), dtype=torch.float32, device="cuda")
+    out_topk_ids = torch.empty((token_num, topk + num_fused_shared_experts), dtype=torch.long, device="cuda")

    assert total_expert_num % num_expert_group == 0

lightllm/common/fused_moe/moe_kernel_configs.py

Lines changed: 5 additions & 5 deletions
@@ -42,12 +42,12 @@ def try_to_get_best_config(
        else:
            if M <= expert_num:
                config = {
-                    "BLOCK_SIZE_M": 16,
-                    "BLOCK_SIZE_N": 32,
-                    "BLOCK_SIZE_K": 64,
-                    "GROUP_SIZE_M": 1,
+                    "BLOCK_SIZE_M": 32,
+                    "BLOCK_SIZE_N": 128,
+                    "BLOCK_SIZE_K": 128,
+                    "GROUP_SIZE_M": 32,
                    "num_warps": 4,
-                    "num_stages": 1,
+                    "num_stages": 3,
                }
            else:
                config = {

lightllm/common/fused_moe/topk_select.py

Lines changed: 2 additions & 0 deletions
@@ -181,6 +181,7 @@ def select_experts(
    num_expert_group: Optional[int] = None,
    scoring_func: str = "softmax",
    custom_routing_function: Optional[Callable] = None,
+    num_fused_shared_experts: int = 0,
):
    from lightllm.common.fused_moe.topk_select import fused_topk
    from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk
@@ -216,6 +217,7 @@ def select_experts(
            topk_group=topk_group,
            scoring_func=scoring_func,
            group_score_used_topk_num=group_score_topk_num,
+            num_fused_shared_experts=num_fused_shared_experts,
        )

    elif custom_routing_function is None:

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 2 deletions
@@ -665,7 +665,8 @@ def _moe_ffn(
        hidden_states = input.view(-1, self.embed_dim_)
        num_tokens, hidden_dim = hidden_states.shape

-        if self.n_shared_experts is not None:
+        # if fused_shared_experts is not enabled, compute shared_output
+        if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
            shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight)

        router_logits = layer_weight.moe_gate.mm(hidden_states)
@@ -681,7 +682,7 @@ def _moe_ffn(

        hidden_states.mul_(self.routed_scaling_factor)

-        if self.n_shared_experts is not None:
+        if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
            hidden_states.add_(shared_output)

        return hidden_states.view(num_tokens, hidden_dim)
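
A short note on why the fused path still gives the shared expert an effective weight of 1.0: _moe_ffn multiplies the MoE output by routed_scaling_factor, and the fused shared expert is routed with weight 1 / routed_scaling_factor, so its contribution comes out as (1 / routed_scaling_factor) * routed_scaling_factor = 1.0. For example, with routed_scaling_factor = 2.5 the fused weight is 0.4, and 0.4 * 2.5 = 1.0, which matches the unfused path above where shared_output is added unscaled after the mul_.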
