@@ -77,6 +77,7 @@
 global_inputs_embeds_mtp_queue = queue.Queue()
 
 
+global norm_out
 
 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
 DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true"
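
The patch moves `norm_out` between pipeline stages through module-level state instead of threading it through the tuple each stage hands to the next. Below is a minimal sketch of that hand-off pattern, with hypothetical function names (not the PaddleNLP code); note that at module scope the bare `global norm_out` statement is only a declaration of intent, since the name is actually created the first time a function assigns it.

```python
# Minimal sketch of the module-global hand-off pattern (hypothetical names).

norm_out = None  # created eagerly here; the patch relies on first assignment instead


def producer_stage(hidden_states):
    """Computes the post-norm activations and stashes them out of band."""
    global norm_out  # required here: this function assigns the module global
    norm_out = hidden_states * 0.5  # stand-in for the real router/norm computation
    return hidden_states


def consumer_stage():
    """Runs later in the same process and reuses the stashed tensor."""
    # A pure read resolves to module scope without a `global` statement.
    return norm_out
```

This keeps the inter-stage tuple at a fixed arity, at the cost of an ordering assumption: the producer must run before the consumer in the same process, and overlapping micro-batches would overwrite the single slot.
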
@@ -183,12 +184,16 @@ def forward_without_residual(self, inputs):
     with paddle.no_grad():
         if self.shared_experts is not None:
             if self.using_post_norm_recompute:
-                shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
-                    hidden_states,
-                    self.shared_experts.norm_weight,
-                    self.shared_experts.norm_eps,
-                    self.shared_experts.w1,
-                    self.shared_experts.w2,
+                global norm_out
+                # shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
+                #     hidden_states,
+                #     self.shared_experts.norm_weight,
+                #     self.shared_experts.norm_eps,
+                #     self.shared_experts.w1,
+                #     self.shared_experts.w2,
+                # )
+                _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                    norm_out, self.shared_experts.w1, self.shared_experts.w2
                 )
             else:
                 _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
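
The removed path fused the RMSNorm into the FP8 shared-expert kernel; the added path reuses the `norm_out` already produced upstream and calls the plain MLP entry point with the signature visible in the diff. A hedged sketch of the relationship, assuming `fp8_mlp_fwd_norm_rc` is approximately "RMSNorm, then `fp8_mlp_fwd`" (the norm below is an illustrative stand-in, not the kernel):

```python
import paddle


def rms_norm_sketch(x, weight, eps):
    # Illustrative stand-in for the norm that fp8_mlp_fwd_norm_rc fused in.
    variance = x.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    return (x.astype("float32") * paddle.rsqrt(variance + eps)).astype(x.dtype) * weight


# Old (removed): shared_out = fp8_mlp_fwd_norm_rc(hidden_states, norm_w, eps, w1, w2)
# New (added):   _, _, shared_out = fp8_mlp_fwd(norm_out, w1, w2)
# where norm_out ~ rms_norm_sketch(hidden_states, norm_w, eps) was computed once
# by the router stage, so this branch no longer re-normalizes hidden_states.
```
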
@@ -566,17 +571,12 @@ def attn_forward(self, inputs):
     inputs = self.attn_and_gate_node.forward(inputs)
 
     if self.send_mtp_embed:
-        if self.using_post_norm_recompute:
-            inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux, norm_out = inputs
-        else:
-            inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux = inputs
+        inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux = inputs
     else:
-        if self.using_post_norm_recompute:
-            hidden_states, residual, probs, routing_map, l_aux, norm_out = inputs
-        else:
-            hidden_states, residual, probs, routing_map, l_aux = inputs
+        hidden_states, residual, probs, routing_map, l_aux = inputs
 
     if self.using_post_norm_recompute:
+        global norm_out
         hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward(
             norm_out, probs, routing_map
         )
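
With `norm_out` out of band, the unpack in `attn_forward` has a single arity per branch. A condensed sketch of the resulting flow (dispatch details omitted; only the recompute branch is shown feeding the dispatcher):

```python
def attn_forward_sketch(self, inputs):
    # Tuple arity now depends only on send_mtp_embed, not on the norm mode.
    if self.send_mtp_embed:
        inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux = inputs
    else:
        hidden_states, residual, probs, routing_map, l_aux = inputs

    if self.using_post_norm_recompute:
        # The patch writes `global norm_out` before this read; for a pure read
        # that only documents the cross-stage dependency, since a bare read
        # already resolves to module scope.
        return norm_out, probs, routing_map  # what the dispatcher consumes
```
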
@@ -1222,6 +1222,7 @@ def attn_compute_for_fusion(self, args):
     _, _, d_model = hidden_states.shape
 
     if self.using_post_norm_recompute:
+        global norm_out
         probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states)
     else:
         probs, routing_map, l_aux, _ = self.mlp.router(hidden_states)
@@ -1236,8 +1237,6 @@ def attn_compute_for_fusion(self, args):
     )
 
     # append mtp embed if needed
     ret = (inputs_embeds_mtp, *ret) if send_mtp_embed else ret
-    # append norm_out if using post_norm recompute
-    ret = (*ret, norm_out) if self.using_post_norm_recompute else ret
 
     return ret
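
On the producer side, the router stage now assigns the module-level `norm_out` instead of appending it to `ret`, so every consumer sees a fixed-shape tuple. A condensed sketch with an illustrative payload (not the full return value):

```python
def attn_compute_for_fusion_sketch(self, hidden_states, inputs_embeds_mtp, send_mtp_embed):
    global norm_out  # the assignment below targets the module-level slot
    if self.using_post_norm_recompute:
        probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states)
    else:
        probs, routing_map, l_aux, _ = self.mlp.router(hidden_states)

    ret = (hidden_states, probs, routing_map, l_aux)  # illustrative payload
    # The mtp embed is still threaded through the tuple; norm_out no longer is.
    ret = (inputs_embeds_mtp, *ret) if send_mtp_embed else ret
    return ret
```
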