
Commit 21e5df6

fix router_scaling_scale
1 parent ce39675 commit 21e5df6

File tree

3 files changed: +6 -10 lines changed


lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 4 additions & 1 deletion

@@ -95,7 +95,7 @@ def __init__(
         self.n_group = network_config["n_group"]
         network_config["topk_group"] = network_config.get("topk_group", 0)
         self.topk_group = network_config["topk_group"]
-        network_config["routed_scaling_factor"] = network_config.get("routed_scaling_factor", 0)
+        network_config["routed_scaling_factor"] = network_config.get("routed_scaling_factor", 1.0)
         self.routed_scaling_factor = network_config["routed_scaling_factor"]

         self.lock = threading.Lock()
@@ -126,6 +126,7 @@ def experts(
             num_expert_group=num_expert_group,
             scoring_func=self.scoring_func,
         )
+        topk_weights.mul_(self.routed_scaling_factor)

         if self.redundancy_expert_num > 0:
             redundancy_topk_ids_repair(
@@ -173,6 +174,7 @@ def low_latency_dispatch(
             num_expert_group=self.n_group,
             scoring_func=self.scoring_func,
         )
+        topk_weights.mul_(self.routed_scaling_factor)

         if self.redundancy_expert_num > 0:
             redundancy_topk_ids_repair(
@@ -213,6 +215,7 @@ def select_experts_and_quant_input(
             num_expert_group=self.n_group,
             scoring_func=self.scoring_func,
         )
+        topk_weights.mul_(self.routed_scaling_factor)
         if self.redundancy_expert_num > 0:
             redundancy_topk_ids_repair(
                 topk_ids=topk_idx,
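For context, the hunks above fold routed_scaling_factor into the routing weights immediately after expert selection, so downstream code receives pre-scaled topk_weights. A minimal illustrative sketch of that behaviour (the softmax/top-k routing below is an assumed stand-in for the repository's expert-selection helper, not lightllm code):

import torch

def select_and_scale(router_logits: torch.Tensor, top_k: int, routed_scaling_factor: float = 1.0):
    # simplified stand-in for expert selection: softmax scoring, then top-k
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)
    # new behaviour: fold the scaling factor into the weights once, at selection time
    topk_weights.mul_(routed_scaling_factor)
    return topk_weights, topk_ids

logits = torch.randn(4, 8)  # 4 tokens, 8 routed experts
weights, ids = select_and_scale(logits, top_k=2, routed_scaling_factor=2.5)

Note that the config default also moves from 0 to 1.0, presumably so that configs without routed_scaling_factor multiply the weights by 1 instead of zeroing them.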

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 2 additions & 1 deletion

@@ -68,6 +68,7 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
             num_expert_group=num_expert_group,
             scoring_func=self.scoring_func,
         )
+        topk_weights.mul_(self.routed_scaling_factor)
         if self.num_fused_shared_experts > 0:
             pad_topk_ids = torch.arange(
                 start=self.n_routed_experts - self.num_fused_shared_experts,
@@ -76,7 +77,7 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
                 dtype=topk_ids.dtype,
                 device="cuda").view(1, self.num_fused_shared_experts).repeat(topk_ids.shape[0], 1)
             pad_topk_weights = torch.full((topk_weights.shape[0], self.num_fused_shared_experts),
-                                          fill_value=1.0 / self.routed_scaling_factor,
+                                          fill_value=1.0,
                                           device="cuda",
                                           dtype=topk_weights.dtype)
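Because the routed weights now arrive pre-scaled, the fused shared experts are appended with a plain weight of 1.0; the old 1.0 / routed_scaling_factor compensation is no longer needed. A small illustrative sketch of the padding step (shapes and the scaling value are made up for the example):

import torch

# assume topk_weights was already multiplied by routed_scaling_factor at selection time
topk_weights = torch.rand(4, 2) * 2.5          # 4 tokens, top_k = 2
num_fused_shared_experts = 1

# shared experts keep an unscaled weight of 1.0
pad_topk_weights = torch.full(
    (topk_weights.shape[0], num_fused_shared_experts),
    fill_value=1.0,
    dtype=topk_weights.dtype,
)
topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1)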

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 8 deletions

@@ -56,7 +56,6 @@ def __init__(self, layer_num, network_config, mode=[]):
         self.norm_topk_prob = network_config["norm_topk_prob"]
         self.n_group = network_config["n_group"]
         self.topk_group = network_config["topk_group"]
-        self.routed_scaling_factor = network_config["routed_scaling_factor"]

         self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
         if network_config.get("rope_scaling", None) is not None:
@@ -680,8 +679,6 @@ def _moe_ffn(
             num_expert_group=self.n_group,
         )

-        hidden_states.mul_(self.routed_scaling_factor)
-
         if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
             hidden_states.add_(shared_output)

@@ -707,7 +704,6 @@ def _moe_ffn_edp(
             num_expert_group=self.n_group,
             is_prefill=infer_state.is_prefill,
         )
-        ep_output.mul_(self.routed_scaling_factor)

         if self.n_shared_experts is not None:
             ep_output.add_(shared_output)
@@ -819,7 +815,6 @@ def overlap_tpsp_token_forward(
         # 0 hook
         if getattr(infer_state, "hook", None) is not None:
             infer_state.hook()
-        _0_ffn_out *= self.routed_scaling_factor
         if self.n_shared_experts is not None:
             _0_ffn_out.add_(_0_shared_output)
         input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
@@ -833,7 +828,6 @@ def overlap_tpsp_token_forward(
         def _1_hook_post():
             _1_hook()
             nonlocal _1_ffn_out
-            _1_ffn_out *= self.routed_scaling_factor
             if self.n_shared_experts is not None:
                 _1_ffn_out.add_(_1_shared_output)
             input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
@@ -965,7 +959,6 @@ def overlap_tpsp_context_forward(

         _1_combine_event = Buffer.capture()

-        _0_ffn_out *= self.routed_scaling_factor
         if self.n_shared_experts is not None:
             _0_ffn_out.add_(_0_shared_output)
         input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
@@ -976,7 +969,6 @@ def overlap_tpsp_context_forward(
         def _1_hook_post():
             _1_hook()
             nonlocal _1_ffn_out
-            _1_ffn_out *= self.routed_scaling_factor
             if self.n_shared_experts is not None:
                 _1_ffn_out.add_(_1_shared_output)
             input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
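The deletions above remove the post-hoc output scaling from the DeepSeek-V2 layer: the factor is now baked into topk_weights, so multiplying the expert output again would double-apply it, and it is also what lets the 1.0 / routed_scaling_factor compensation on the fused shared experts be dropped in the previous file. A schematic before/after of the combine step as implied by the diff (illustrative only, not the repository's code):

import torch

def combine_old(routed_out, shared_out, routed_scaling_factor):
    # old flow: scale the routed output after the expert computation,
    # which forced the fused shared-expert weights to be pre-divided
    return routed_out * routed_scaling_factor + shared_out

def combine_new(routed_out, shared_out):
    # new flow: routed_out was built from pre-scaled topk_weights,
    # so no extra multiply is needed here
    return routed_out + shared_out

routed, shared = torch.ones(2, 4), torch.ones(2, 4)
out = combine_new(routed, shared)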
