 dispose_tensor, get_ascend_soc_version)
 
 
+def apply_mlp_decode(hidden_states: torch.Tensor,
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: torch.Tensor = None,
+                     group_list_type: int = 1) -> torch.Tensor:
+    """
+    apply MLP: gate_up_proj -> swiglu -> down_proj
+    Args:
+        hidden_states: input hidden states with shape (num_tokens, hidden_size).
+        w1: expert weights1 with shape
+            (num_experts, hidden_size, intermediate_size * 2)
+        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
+        w2: expert weights2 with shape
+            (num_experts, intermediate_size, hidden_size)
+        w2_scale: weights2 scale with shape (num_experts, hidden_size)
+        group_list: number of tokens for each expert, following cumsum mode,
+            with shape (num_experts).
+        transpose_weight:
+            w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                (num_experts, hidden_size, intermediate_size * 2)
+            w2: (num_experts, hidden_size, intermediate_size) ->
+                (num_experts, intermediate_size, hidden_size)
+    Returns:
+        hidden_states: output hidden states after MLP.
+    """
+
+    if dynamic_scale is None:
+        unquantized_hidden_states = hidden_states
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+        # Dispose the original unquantized hidden states
+        # to save npu memory because they're no longer used.
+        dispose_tensor(unquantized_hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
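
Note: for reference, a minimal sketch of how the new apply_mlp_decode helper might be invoked in isolation. The shapes, dtypes, and values below are illustrative assumptions (int8 expert weights with per-channel scales, group_list holding per-expert token counts), not taken from this PR, and the torch_npu ops require an Ascend NPU device to actually run.

import torch

num_tokens, hidden_size, intermediate_size, num_experts = 16, 1024, 1408, 8
hidden_states = torch.randn(num_tokens, hidden_size,
                            dtype=torch.bfloat16, device="npu")
w1 = torch.randint(-128, 128, (num_experts, hidden_size, intermediate_size * 2),
                   dtype=torch.int8, device="npu")
w1_scale = torch.rand(num_experts, intermediate_size * 2,
                      dtype=torch.float32, device="npu")
w2 = torch.randint(-128, 128, (num_experts, intermediate_size, hidden_size),
                   dtype=torch.int8, device="npu")
w2_scale = torch.rand(num_experts, hidden_size,
                      dtype=torch.bfloat16, device="npu")
# one entry per expert; here every expert handles the same number of tokens
group_list = torch.full((num_experts, ), num_tokens // num_experts,
                        dtype=torch.int64, device="npu")
# dynamic_scale is omitted, so the helper quantizes hidden_states itself
out = apply_mlp_decode(hidden_states, w1, w1_scale, w2, w2_scale, group_list)
assert out.shape == (num_tokens, hidden_size)
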
@@ -124,6 +200,8 @@ def fused_experts_with_mc2(
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
+    shared_gate_up: Optional[Any] = None,
+    shared_dequant_scale: Optional[Any] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     assert mc2_mask is not None
     if log2phy is not None:
@@ -186,18 +264,19 @@ def fused_experts_with_mc2(
 
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(quantized_x_for_share, expand_x)
+            npu_wait_tensor(shared_gate_up, expand_x)
             shared_act_out = shared_experts.act_fn(
-                (quantized_x_for_share, dynamic_scale_for_share))
+                (shared_gate_up, shared_dequant_scale))
             shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
 
-    down_out_list = apply_mlp(expand_x,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+    # `expand_x` will be disposed in the `apply_mlp_decode` function
+    down_out_list = apply_mlp_decode(expand_x,
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)
 
     # moeCombine
     kwargs_mc2 = {
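
Note: the comment above matters because apply_mlp_decode releases expand_x via dispose_tensor once it has been quantized, so the caller must not touch it afterwards. As a rough sketch of the idea only (the real helper lives in vllm_ascend.utils and its implementation may differ), disposing a tensor amounts to swapping its storage for an empty one so the NPU buffer can be reclaimed early.

import torch

def dispose_tensor_sketch(t: torch.Tensor) -> None:
    # Point the tensor at a zero-element storage; the original buffer
    # becomes unreferenced and can be freed before the function returns.
    t.set_(torch.empty((0, ), device=t.device, dtype=t.dtype))
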
@@ -745,6 +824,8 @@ def apply(
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
         shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -781,15 +862,23 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )
 
+        fused_moe_state = get_forward_context().fused_moe_state
+        shared_gate_up, shared_dequant_scale = None, None
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+            with npu_stream_switch("moe_secondary", 0):
+                npu_wait_tensor(quantized_x_for_share, router_logits)
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]
+
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         topk_weights = topk_weights.to(x.dtype)
-
-        fused_moe_state = get_forward_context().fused_moe_state
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
                 hidden_states=x,
@@ -806,7 +895,7 @@ def apply(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
+                w1_scale=layer.w13_weight_scale_fp32,
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
@@ -817,7 +906,9 @@ def apply(
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
-                mc2_mask=kwargs.get("mc2_mask", None))
+                mc2_mask=kwargs.get("mc2_mask", None),
+                shared_gate_up=shared_gate_up,
+                shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
             FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
@@ -860,6 +951,8 @@ def process_weights_after_loading(self, layer):
         torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
+        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+            torch.float32)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
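
Note: the cached w13_weight_scale_fp32 is what the MC2 branch of apply() now passes as w1_scale, so the decode path reuses a float32 copy of the scale instead of re-casting it on every forward call. A minimal, self-contained sketch of the same caching pattern, with hypothetical layer and attribute names (illustration only, not the PR's class):

import torch

class MoEQuantLayerSketch(torch.nn.Module):  # hypothetical name
    def __init__(self, num_experts: int, scale_dim: int):
        super().__init__()
        # per-expert dequant scale, as it might be stored after checkpoint load
        self.w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, scale_dim, dtype=torch.bfloat16),
            requires_grad=False)

    def process_weights_after_loading(self):
        # one-time cast, cached next to the original scale so the hot path
        # never converts dtypes per call
        self.w13_weight_scale_fp32 = self.w13_weight_scale.data.to(
            torch.float32)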