@@ -171,10 +171,10 @@ def __init__(self, *args, **kwargs):
171171 self .moe_config .dp_group = get_dp_group ()
172172 self .moe_config .ep_group = get_ep_group ()
173173 self .moe_config .mc2_group = get_mc2_group ()
174- ascend_config = get_ascend_config ()
175- self .dynamic_eplb = ascend_config .dynamic_eplb or ascend_config .expert_map_record_path
176- self .expert_map_path = ascend_config .expert_map_path
177- self .global_redundant_expert_num = ascend_config .init_redundancy_expert
174+ self . ascend_config = get_ascend_config ()
175+ self .dynamic_eplb = self . ascend_config .dynamic_eplb or self . ascend_config .expert_map_record_path
176+ self .expert_map_path = self . ascend_config .expert_map_path
177+ self .global_redundant_expert_num = self . ascend_config .init_redundancy_expert
178178 self .global_num_experts = num_experts + self .global_redundant_expert_num
179179 if self .custom_routing_function is None and self .e_score_correction_bias is not None :
180180 vllm_config = get_current_vllm_config ()
@@ -194,8 +194,8 @@ def __init__(self, *args, **kwargs):
194194 self .expert_load_balancer = ExpertLoadBalancer (
195195 self .expert_map_path , num_experts )
196196 self .expert_load_balancer .check_expert_map_tensor ()
197- self .global_redundant_expert_num = (
198- self .expert_load_balancer .get_global_redundant_expert_num ())
197+ # self.global_redundant_expert_num = (
198+ # self.expert_load_balancer.get_global_redundant_expert_num())
199199 self .global_num_experts = num_experts + self .global_redundant_expert_num
200200 try :
201201 self .local_num_experts , self .expert_map = (
@@ -253,7 +253,7 @@ def __init__(self, *args, **kwargs):
253253 moe_quant_params ["intermediate_size_full" ] = intermediate_size
254254 self .quant_method .create_weights (layer = self , ** moe_quant_params )
255255
256- self .enable_shared_expert_dp = ascend_config .enable_shared_expert_dp
256+ self .enable_shared_expert_dp = self . ascend_config .enable_shared_expert_dp
257257
258258 setup_moe_comm_method (self .moe_config )
259259 self .quant_type = self ._get_quant_type ()
@@ -459,8 +459,8 @@ def __init__(
459459 self ._shared_experts = shared_experts
460460 self .use_overlapped = use_overlapped
461461 self .shared_expert_stream = None
462- ascend_config = get_ascend_config ()
463- self .multistream_overlap_shared_expert = ascend_config .multistream_overlap_shared_expert
462+ self . ascend_config = get_ascend_config ()
463+ self .multistream_overlap_shared_expert = self . ascend_config .multistream_overlap_shared_expert
464464 if enable_sp ():
465465 logger .info_once (
466466 "Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -488,11 +488,19 @@ def forward(
488488 hidden_states : torch .Tensor ,
489489 router_logits : torch .Tensor ,
490490 ) -> tuple [torch .Tensor , torch .Tensor ]:
491- shared_out , fused_out = AscendFusedMoE .forward (
492- self ,
493- hidden_states = hidden_states ,
494- router_logits = router_logits ,
495- )
491+ if self ._shared_experts is None :
492+ fused_out = AscendFusedMoE .forward (
493+ self ,
494+ hidden_states = hidden_states ,
495+ router_logits = router_logits ,
496+ )
497+ shared_out = None
498+ else :
499+ shared_out , fused_out = AscendFusedMoE .forward (
500+ self ,
501+ hidden_states = hidden_states ,
502+ router_logits = router_logits ,
503+ )
496504 return shared_out , fused_out
497505
498506 def forward_impl (self , hidden_states : torch .Tensor ,
@@ -506,7 +514,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
506514 # Use a separate stream to run shared experts.
507515 # Note that currently we only support calculations in separate streams with aclgraph.
508516 # Communication operations in another stream might cause unknown errors.
509- shared_out = self ._shared_experts (hidden_states )
517+ if self ._shared_experts is None :
518+ shared_out = None
519+ else :
520+ shared_out = self ._shared_experts (hidden_states )
510521
511522 fused_output = AscendFusedMoE .forward_impl (
512523 self ,
@@ -521,6 +532,9 @@ def forward_impl(self, hidden_states: torch.Tensor,
521532 forward_context = get_forward_context ()
522533 moe_comm_type = forward_context .moe_comm_type
523534 if moe_comm_type in {MoECommType .ALLTOALL , MoECommType .MC2 } \
524- and not shared_expert_dp_enabled ():
535+ and not shared_expert_dp_enabled () and shared_out is not None :
525536 shared_out = tensor_model_parallel_all_reduce (shared_out )
526- return shared_out , fused_output
537+ if shared_out is None :
538+ return fused_output
539+ else :
540+ return shared_out , fused_output
0 commit comments