
Commit 323c590

save global arg for rms_norm output
1 parent b5ebfdd commit 323c590

File tree

1 file changed (+15, -16)


paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 15 additions & 16 deletions
@@ -77,6 +77,7 @@
 
 global_inputs_embeds_mtp_queue = queue.Queue()
 
+global norm_out
 
 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
 DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true"
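
The hunks that follow bind and read this new module-level norm_out through global declarations (at module scope, as on added line 80, a global statement is a no-op and effectively just documents the shared name). A minimal self-contained sketch of the pattern, with illustrative names that are not from the commit:

norm_out = None  # module-level slot shared by the stage helpers

def save_norm(x):
    # Rebinding a module-level name inside a function requires a
    # `global` declaration; a bare assignment would create a local.
    global norm_out
    norm_out = [v * 0.5 for v in x]  # stand-in for an RMSNorm output

def use_norm():
    # Read-only access needs no declaration at all.
    return norm_out

save_norm([2.0, 4.0])
print(use_norm())  # [1.0, 2.0]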
@@ -183,12 +184,16 @@ def forward_without_residual(self, inputs):
         with paddle.no_grad():
             if self.shared_experts is not None:
                 if self.using_post_norm_recompute:
-                    shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
-                        hidden_states,
-                        self.shared_experts.norm_weight,
-                        self.shared_experts.norm_eps,
-                        self.shared_experts.w1,
-                        self.shared_experts.w2,
+                    global norm_out
+                    # shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
+                    #     hidden_states,
+                    #     self.shared_experts.norm_weight,
+                    #     self.shared_experts.norm_eps,
+                    #     self.shared_experts.w1,
+                    #     self.shared_experts.w2,
+                    # )
+                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                        norm_output, self.shared_experts.w1, self.shared_experts.w2
                     )
                 else:
                     _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
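
Under using_post_norm_recompute, the shared-expert path previously recomputed the RMSNorm inside the fused fp8_mlp_fwd_norm_rc; this hunk instead feeds the norm output already produced elsewhere (and saved in the global) to the plain fp8_mlp_fwd. A toy sketch of why the two call shapes agree, using stand-in functions rather than the real FP8 kernels:

import math

def rms_norm(x, eps=1e-6):
    # Simplified RMSNorm without a learned weight.
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]

def mlp_fwd(x, w1, w2):
    # Toy stand-in for the two-matmul FP8 MLP forward.
    return [v * w1 * w2 for v in x]

norm_out = None

def fused_path(hidden_states, w1, w2):
    # Old call shape: the norm is recomputed inside the fused op.
    return mlp_fwd(rms_norm(hidden_states), w1, w2)

def cached_path(w1, w2):
    # New call shape: reuse the norm output saved in the global.
    return mlp_fwd(norm_out, w1, w2)

norm_out = rms_norm([1.0, 2.0, 3.0])
assert fused_path([1.0, 2.0, 3.0], 0.5, 2.0) == cached_path(0.5, 2.0)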
@@ -566,17 +571,12 @@ def attn_forward(self, inputs):
         inputs = self.attn_and_gate_node.forward(inputs)
 
         if self.send_mtp_embed:
-            if self.using_post_norm_recompute:
-                inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux, norm_out = inputs
-            else:
-                inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux = inputs
+            inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux = inputs
         else:
-            if self.using_post_norm_recompute:
-                hidden_states, residual, probs, routing_map, l_aux, norm_out = inputs
-            else:
-                hidden_states, residual, probs, routing_map, l_aux = inputs
+            hidden_states, residual, probs, routing_map, l_aux = inputs
 
         if self.using_post_norm_recompute:
+            global norm_out
             hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward(
                 norm_out, probs, routing_map
             )
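
Since norm_out now travels through the module global rather than the stage-to-stage inputs tuple, the tuple keeps one arity regardless of the recompute flag, and the unpacking above loses its branches. A small sketch of that simplification, mirroring the field names from the hunk:

norm_out = None  # module-level, as in the hunk above

def unpack_before(inputs, using_post_norm_recompute):
    # Old: tuple arity depended on the flag, so consumers branched.
    if using_post_norm_recompute:
        hidden_states, residual, probs, routing_map, l_aux, norm = inputs
    else:
        hidden_states, residual, probs, routing_map, l_aux = inputs
        norm = None
    return hidden_states, norm

def unpack_after(inputs, using_post_norm_recompute):
    # New: the tuple always carries five fields; the norm output is
    # read from the module global only when the flag is set.
    hidden_states, residual, probs, routing_map, l_aux = inputs
    norm = norm_out if using_post_norm_recompute else None
    return hidden_states, norm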
@@ -1222,6 +1222,7 @@ def attn_compute_for_fusion(self, args):
         _, _, d_model = hidden_states.shape
 
         if self.using_post_norm_recompute:
+            global norm_out
             probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states)
         else:
             probs, routing_map, l_aux, _ = self.mlp.router(hidden_states)
@@ -1236,8 +1237,6 @@
         )
         # append mtp embed if needed
         ret = (inputs_embeds_mtp, *ret) if send_mtp_embed else ret
-        # append norm_out if using post_norm recompute
-        ret = (*ret, norm_out) if self.using_post_norm_recompute else ret
 
         return ret
