Commit 169c4bd

Author: tanqingshan (A) (committed)

Feat-4470
Signed-off-by: tanqingshan (A) <50050625@china.huawei.com>
1 parent 7a21bd7 commit 169c4bd

File tree: 7 files changed, +384 -36 lines

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 1 deletion
@@ -34,6 +34,7 @@ class AscendConfig:
 
    def __init__(self, vllm_config):
        additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
+       self.mix_placement = additional_config.get("mix_placement",False)
        torchair_graph_config = additional_config.get("torchair_graph_config",
                                                      {})
 
@@ -368,4 +369,4 @@ def check_ascend_config(vllm_config, enforce_eager):
        logger.warning(
            "ACL Graph is currently experimental. Please "
            "raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
-           " if you encourage any Error")
+           " if you encourage any Error")

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 14 additions & 14 deletions
@@ -221,12 +221,13 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
            json.dump(record, f, indent=4)
 
    def do_update_expert_map(self, layer_id, updated_expert_map):
-       pad_len = self.expert_map_per_layer[layer_id].shape[
-           0] - updated_expert_map.shape[0]
-       updated_expert_map_padded = torch.nn.functional.pad(updated_expert_map,
-                                                           pad=(0, pad_len),
-                                                           mode='constant',
-                                                           value=-1)
+       pad_len = self.expert_map_per_layer[layer_id].shape[0] - updated_expert_map.shape[0]
+       updated_expert_map_padded = torch.nn.functional.pad(
+           updated_expert_map,
+           pad=(0,pad_len),
+           mode='constant',
+           value=-1
+       )
        self.expert_map_per_layer[layer_id].copy_(updated_expert_map_padded)
        self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
 
@@ -240,15 +241,14 @@ def do_update_expert_weight(self, layer_id, local_expert_to_replace,
 
    def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
        if self.log2phy_map_per_layer[layer_id] is not None:
-           pad_len = self.log2phy_map_per_layer[layer_id].shape[
-               0] - updated_log2phy_map.shape[0]
+           pad_len = self.log2phy_map_per_layer[layer_id].shape[0] - updated_log2phy_map.shape[0]
            updated_log2phy_map_padded = torch.nn.functional.pad(
-               updated_log2phy_map,
-               pad=(0, pad_len),
-               mode='constant',
-               value=-1)
-           self.log2phy_map_per_layer[layer_id].copy_(
-               updated_log2phy_map_padded)
+               updated_log2phy_map,
+               pad=(0,pad_len),
+               mode='constant',
+               value=-1
+           )
+           self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map_padded)
 
    def global2local(self, placement: torch.Tensor,
                     E_local: int) -> torch.Tensor:
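For reference, a minimal standalone sketch of the padding behavior above: a shorter 1-D expert map is right-padded with -1 up to the length of the stored map before being copied in place (values below are made up for illustration):

    import torch

    stored_map = torch.tensor([0, 1, 2, -1, -1, -1])  # length kept by expert_map_per_layer
    updated_map = torch.tensor([2, 0, 1, 3])          # shorter updated map

    pad_len = stored_map.shape[0] - updated_map.shape[0]
    padded = torch.nn.functional.pad(updated_map, pad=(0, pad_len), mode='constant', value=-1)
    stored_map.copy_(padded)
    # stored_map is now tensor([ 2,  0,  1,  3, -1, -1])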

vllm_ascend/ops/fused_moe/experts_selector.py

Lines changed: 16 additions & 1 deletion
@@ -33,6 +33,8 @@ def select_experts(hidden_states: torch.Tensor,
                   routed_scaling_factor=1.0,
                   e_score_correction_bias: Optional[torch.Tensor] = None,
                   indices_type: Optional[torch.dtype] = None,
+                  mix_placement: Optional[bool] = False,
+                  num_logical_experts: int = -1,
                   global_num_experts: int = -1):
    """
    Fused experts with select experts.

@@ -87,6 +89,19 @@ def select_experts(hidden_states: torch.Tensor,
        e_score_correction_bias=e_score_correction_bias,
        global_num_experts=global_num_experts,
    )
+   if mix_placement:
+       pad_shared_expert_ids = torch.full((topk_ids.shape[0], 1),
+                                          num_logical_experts,
+                                          dtype=topk_ids.dtype,
+                                          device=topk_ids.device)
+
+       pad_shared_expert_weights = torch.full((topk_weights.shape[0], 1),
+                                              0.4,
+                                              dtype=topk_weights.dtype,
+                                              device=topk_weights.device)
+       topk_ids = torch.cat([topk_ids, pad_shared_expert_ids], dim=1)
+       topk_weights = torch.cat([topk_weights, pad_shared_expert_weights],
+                                dim=1)
    return topk_weights, topk_ids
 
 
@@ -271,4 +286,4 @@ def _native_select_experts(
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
-   return topk_weights, topk_ids
+   return topk_weights, topk_ids
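For context, a minimal sketch of what the new mix_placement branch does in isolation: one extra column is appended to topk_ids pointing at the shared-expert slot (num_logical_experts) with a fixed weight of 0.4 (shapes and values below are illustrative, not from this commit):

    import torch

    topk_ids = torch.tensor([[1, 5], [3, 7]], dtype=torch.int32)  # [num_tokens, top_k]
    topk_weights = torch.tensor([[0.6, 0.4], [0.5, 0.5]])
    num_logical_experts = 8  # assumed id of the shared-expert slot

    pad_ids = torch.full((topk_ids.shape[0], 1), num_logical_experts,
                         dtype=topk_ids.dtype, device=topk_ids.device)
    pad_weights = torch.full((topk_weights.shape[0], 1), 0.4,
                             dtype=topk_weights.dtype, device=topk_weights.device)

    topk_ids = torch.cat([topk_ids, pad_ids], dim=1)              # [[1, 5, 8], [3, 7, 8]]
    topk_weights = torch.cat([topk_weights, pad_weights], dim=1)  # [[0.6, 0.4, 0.4], [0.5, 0.5, 0.4]]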

vllm_ascend/ops/fused_moe/fused_moe.py

Lines changed: 31 additions & 17 deletions
@@ -171,10 +171,10 @@ def __init__(self, *args, **kwargs):
        self.moe_config.dp_group = get_dp_group()
        self.moe_config.ep_group = get_ep_group()
        self.moe_config.mc2_group = get_mc2_group()
-       ascend_config = get_ascend_config()
-       self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
-       self.expert_map_path = ascend_config.expert_map_path
-       self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+       self.ascend_config = get_ascend_config()
+       self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
+       self.expert_map_path = self.ascend_config.expert_map_path
+       self.global_redundant_expert_num = self.ascend_config.init_redundancy_expert
        self.global_num_experts = num_experts + self.global_redundant_expert_num
        if self.custom_routing_function is None and self.e_score_correction_bias is not None:
            vllm_config = get_current_vllm_config()

@@ -194,8 +194,8 @@ def __init__(self, *args, **kwargs):
            self.expert_load_balancer = ExpertLoadBalancer(
                self.expert_map_path, num_experts)
            self.expert_load_balancer.check_expert_map_tensor()
-           self.global_redundant_expert_num = (
-               self.expert_load_balancer.get_global_redundant_expert_num())
+           # self.global_redundant_expert_num = (
+           #     self.expert_load_balancer.get_global_redundant_expert_num())
            self.global_num_experts = num_experts + self.global_redundant_expert_num
            try:
                self.local_num_experts, self.expert_map = (

@@ -253,7 +253,7 @@ def __init__(self, *args, **kwargs):
            moe_quant_params["intermediate_size_full"] = intermediate_size
        self.quant_method.create_weights(layer=self, **moe_quant_params)
 
-       self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
+       self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
 
        setup_moe_comm_method(self.moe_config)
        self.quant_type = self._get_quant_type()

@@ -459,8 +459,8 @@ def __init__(
        self._shared_experts = shared_experts
        self.use_overlapped = use_overlapped
        self.shared_expert_stream = None
-       ascend_config = get_ascend_config()
-       self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
+       self.ascend_config = get_ascend_config()
+       self.multistream_overlap_shared_expert = self.ascend_config.multistream_overlap_shared_expert
        if enable_sp():
            logger.info_once(
                "Sequence parallelism is enabled, shared experts are replicated for best performance."

@@ -488,11 +488,19 @@ def forward(
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
-       shared_out, fused_out = AscendFusedMoE.forward(
-           self,
-           hidden_states=hidden_states,
-           router_logits=router_logits,
-       )
+       if self._shared_experts is None:
+           fused_out = AscendFusedMoE.forward(
+               self,
+               hidden_states=hidden_states,
+               router_logits=router_logits,
+           )
+           shared_out = None
+       else:
+           shared_out, fused_out = AscendFusedMoE.forward(
+               self,
+               hidden_states=hidden_states,
+               router_logits=router_logits,
+           )
        return shared_out, fused_out
 
    def forward_impl(self, hidden_states: torch.Tensor,

@@ -506,7 +514,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
            # Use a separate stream to run shared experts.
            # Note that currently we only support calculations in separate streams with aclgraph.
            # Communication operations in another stream might cause unknown errors.
-           shared_out = self._shared_experts(hidden_states)
+           if self._shared_experts is None:
+               shared_out = None
+           else:
+               shared_out = self._shared_experts(hidden_states)
 
        fused_output = AscendFusedMoE.forward_impl(
            self,

@@ -521,6 +532,9 @@ def forward_impl(self, hidden_states: torch.Tensor,
        forward_context = get_forward_context()
        moe_comm_type = forward_context.moe_comm_type
        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
-               and not shared_expert_dp_enabled():
+               and not shared_expert_dp_enabled() and shared_out is not None:
            shared_out = tensor_model_parallel_all_reduce(shared_out)
-       return shared_out, fused_output
+       if shared_out is None:
+           return fused_output
+       else:
+           return shared_out, fused_output
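Note that forward_impl now returns either a bare fused tensor (when no shared experts are configured) or a (shared_out, fused_output) tuple. A small illustrative helper a caller could use to normalize both shapes (not part of this commit):

    import torch

    def unpack_moe_output(out):
        # A bare tensor means the layer has no shared experts.
        if isinstance(out, torch.Tensor):
            return None, out
        shared_out, fused_out = out
        return shared_out, fused_out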

vllm_ascend/ops/fused_moe/moe_mlp.py

Lines changed: 2 additions & 3 deletions
@@ -127,9 +127,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
    if quantized_hidden_states is not None:
        dispose_tensor(quantized_hidden_states)
    # act_fn: swiglu
-   group_diff = torch.diff(group_list, dim=0)
-   new_group = torch.cat([group_list[0].unsqueeze(0), group_diff],
-                         dim=0)
+   group_diff = torch.diff(group_list)
+   new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
        x=hidden_states,
        weight_scale=w1_scale,
vllm_ascend/patch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -138,3 +138,4 @@
 # Future Plan:
 # Remove this patch when adapted vllm version contains the above PR.
 #
+from vllm_ascend.patch.worker import patch_deepseekv3
