Commit b6a7f07

[Perf][MoE] Improve MoE multistream parallel performance. (vllm-project#1891)
This PR refines the design of the shared-expert multi-stream parallelism in the w8a8-dynamic-quantized MoE stage to achieve better performance.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@2cc5711

Signed-off-by: whx-sjtu <[email protected]>

1 parent 4df8e00 commit b6a7f07

File tree (3 files changed: +124 -14 lines)

- vllm_ascend/models/deepseek_v2.py
- vllm_ascend/ops/fused_moe.py
- vllm_ascend/quantization/w8a8_dynamic.py

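The change overlaps the shared expert's quantized projections with the routed-expert pipeline on a secondary NPU stream. Below is a minimal sketch of that overlap pattern, condensed from the hunks that follow; it assumes npu_stream_switch and npu_wait_tensor are importable from vllm_ascend.utils (the import path is not visible in this diff), and it omits the MC2 dispatch/combine as well as the shared expert's act_fn and down_proj, which wait on the dispatched tokens inside fused_experts_with_mc2.

# Illustrative sketch only, not the actual implementation in this commit.
import torch_npu
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor  # assumed path


def multistream_moe_prologue(hidden_states, gate, shared_experts):
    # Main stream: router logits for the routed experts.
    router_logits, _ = gate(hidden_states)

    # Secondary stream: per-token dynamic quantization of the same
    # activations, overlapped with the gate GEMM above.
    with npu_stream_switch("moe_secondary", 0):
        quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(
            hidden_states)

    # Secondary stream: the shared expert's gate_up_proj starts only once
    # the router logits exist, so it overlaps the MC2 dispatch of the
    # routed tokens on the main stream.
    with npu_stream_switch("moe_secondary", 0):
        npu_wait_tensor(quantized_x, router_logits)
        share_up_out, _ = shared_experts.gate_up_proj(
            (quantized_x, dynamic_scale))
    return router_logits, share_up_out
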
vllm_ascend/models/deepseek_v2.py

Lines changed: 1 addition & 1 deletion
@@ -393,7 +393,7 @@ def forward(self,
 
         # router_logits: (num_tokens, n_experts)
         router_logits = None
-        if not self.rm_router_logits:
+        if not self.rm_router_logits and not self.enable_multistream_moe:
             router_logits, _ = self.gate(hidden_states)
 
         experts_hidden_states = self.experts(

vllm_ascend/ops/fused_moe.py

Lines changed: 17 additions & 0 deletions
@@ -1334,6 +1334,21 @@ def forward(self,
         forward_context = get_forward_context()
         fused_moe_state = forward_context.fused_moe_state
         mc2_mask = forward_context.mc2_mask
+        # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
+        quantized_x_for_share, dynamic_scale_for_share = None, None
+        from vllm_ascend.quantization.w8a8_dynamic import \
+            AscendW8A8DynamicFusedMoEMethod
+        if self.enable_multistream_moe:
+            if not self.rm_router_logits:
+                router_logits, _ = gate(hidden_states)
+            if hasattr(self.quant_method, "quant_method") and \
+                isinstance(self.quant_method.quant_method,
+                           AscendW8A8DynamicFusedMoEMethod
+                           ) and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                        hidden_states)
+
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
@@ -1419,6 +1434,8 @@ def forward(self,
             shared_experts=shared_experts if self.torchair_graph_enabled
             and self.enable_multistream_moe and not is_prefill else None,
             mc2_mask=mc2_mask,
+            quantized_x_for_share=quantized_x_for_share,
+            dynamic_scale_for_share=dynamic_scale_for_share,
         )
 
         if shared_experts:

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 106 additions & 13 deletions
@@ -33,6 +33,82 @@
                                dispose_tensor, get_ascend_soc_version)
 
 
+def apply_mlp_decode(hidden_states: torch.Tensor,
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: torch.Tensor = None,
+                     group_list_type: int = 1) -> torch.Tensor:
+    """
+    apply MLP: gate_up_proj -> swiglu -> down_proj
+    Args:
+        hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
+        w1: expert weights1 with shape
+            (num_experts, hidden_size, intermediate_size * 2)
+        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
+        w2: expert weights2 with shape
+            (num_experts, intermediate_size, hidden_size)
+        w2_scale: weights2 scale with shape (num_experts, hidden_size)
+        group_list: number of tokens for each expert, follow cumsum mode, and
+            with shape (num_experts).
+        transpose_weight:
+            w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                (num_experts, hidden_size, intermediate_size * 2)
+            w2: (num_experts, hidden_size, intermediate_size) ->
+                (num_experts, intermediate_size, hidden_size)
+    Returns:
+        hidden_states: output hidden states after MLP.
+    """
+
+    if dynamic_scale is None:
+        unquantized_hidden_states = hidden_states
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+        # Dispose the original unquantized hidden states
+        # to save npu memory because they're no longer used.
+        dispose_tensor(unquantized_hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
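
For reference without an Ascend device, here is a rough single-expert, pure-PyTorch sketch of the arithmetic that apply_mlp_decode offloads to npu_grouped_matmul and npu_dequant_swiglu_quant: int8 gate_up GEMM, dequantization with the per-channel weight scale and the per-token activation scale, SwiGLU with the gate on the left half (activate_left=True), per-token re-quantization, and the int8 down GEMM. It assumes symmetric per-token int8 quantization and ignores expert grouping and memory disposal; it is an illustration, not code from this commit.

import torch
import torch.nn.functional as F


def mlp_decode_reference(x_fp, w1_int8, w1_scale, w2_int8, w2_scale):
    # Per-token symmetric int8 quantization (what npu_dynamic_quant produces).
    x_scale = x_fp.abs().amax(dim=-1, keepdim=True) / 127.0
    x_q = torch.clamp((x_fp / x_scale).round(), -128, 127)

    # gmm1: gate_up_proj accumulated in int32, then dequantized with the
    # per-channel weight scale and the per-token activation scale.
    gate_up = (x_q.to(torch.int32) @ w1_int8.to(torch.int32)).to(torch.float32)
    gate_up = gate_up * w1_scale * x_scale

    # act_fn: swiglu with the gate half on the left (activate_left=True).
    gate, up = gate_up.chunk(2, dim=-1)
    act = F.silu(gate) * up

    # Re-quantize per token (the role of swiglu_out_scale), then gmm2: down_proj.
    act_scale = act.abs().amax(dim=-1, keepdim=True) / 127.0
    act_q = torch.clamp((act / act_scale).round(), -128, 127)
    out = (act_q.to(torch.int32) @ w2_int8.to(torch.int32)).to(torch.float32)
    return out * w2_scale * act_scale
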
@@ -124,6 +200,8 @@ def fused_experts_with_mc2(
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
+    shared_gate_up: Optional[Any] = None,
+    shared_dequant_scale: Optional[Any] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     assert mc2_mask is not None
     if log2phy is not None:
@@ -186,18 +264,19 @@ def fused_experts_with_mc2(
 
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(quantized_x_for_share, expand_x)
+            npu_wait_tensor(shared_gate_up, expand_x)
             shared_act_out = shared_experts.act_fn(
-                (quantized_x_for_share, dynamic_scale_for_share))
+                (shared_gate_up, shared_dequant_scale))
             shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
 
-    down_out_list = apply_mlp(expand_x,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+    # `expand_x` will be disposed in the `apply_mlp` function
+    down_out_list = apply_mlp_decode(expand_x,
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)
 
     # moeCombine
     kwargs_mc2 = {
@@ -745,6 +824,8 @@ def apply(
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
         shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -781,15 +862,23 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )
 
+        fused_moe_state = get_forward_context().fused_moe_state
+        shared_gate_up, shared_dequant_scale = None, None
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+            with npu_stream_switch("moe_secondary", 0):
+                npu_wait_tensor(quantized_x_for_share, router_logits)
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]
+
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         topk_weights = topk_weights.to(x.dtype)
-
-        fused_moe_state = get_forward_context().fused_moe_state
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
                 hidden_states=x,
@@ -806,7 +895,7 @@ def apply(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
+                w1_scale=layer.w13_weight_scale_fp32,
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
@@ -817,7 +906,9 @@ def apply(
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
-                mc2_mask=kwargs.get("mc2_mask", None))
+                mc2_mask=kwargs.get("mc2_mask", None),
+                shared_gate_up=shared_gate_up,
+                shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
@@ -860,6 +951,8 @@ def process_weights_after_loading(self, layer):
         torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
+        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+            torch.float32)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
