Commit 87fa744

Authored by chen2016013 and root

Fix mtp bug when send_mtp_embed=True (#10909)

* Fix mtp bug
* fix mtp bug
* Update fp8_utils.py
* fix mtp bug

Co-authored-by: root <[email protected]>

1 parent: 2ff16f8; commit: 87fa744

File tree: 2 files changed (+51, -27 lines)

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 32 additions & 15 deletions
@@ -192,6 +192,7 @@ def forward_without_residual(self, inputs):

         if self.send_mtp_embed:
             hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+            self.mtp_embed_shape = inputs_embeds_mtp.shape  # save the shape of mtp_embed for the backward pass

         return return_args(hidden_states)

@@ -227,37 +228,47 @@ def forward(self, inputs):

         if self.send_mtp_embed:
             hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+            self.mtp_embed_shape = inputs_embeds_mtp.shape  # save the shape of mtp_embed for the backward pass

         return return_args(hidden_states)

     @paddle.no_grad()
     def backward(self, output_grad):
         (do3,) = output_grad

-        assert not self.send_mtp_embed, "not support have mtp have yet"
+        if self.send_mtp_embed:
+            # split the gradient: the leading part of do3 corresponds to hidden_states, the trailing part to inputs_embeds_mtp
+            hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1]
+            hidden_states_grad = do3[..., :hidden_size]
+            inputs_embeds_mtp_grad = do3[..., hidden_size:]
+        else:
+            hidden_states_grad = do3
+            inputs_embeds_mtp_grad = None
+
         if self.using_post_norm_recompute:
             dx = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
-                do3,
+                hidden_states_grad,
                 self.x,
                 self.shared_experts.norm_weight,
                 self.shared_experts.norm_eps,
                 self.shared_experts.w1,
                 self.shared_experts.w2,
             )
         else:
-            dx = FP8LinearFunctionBase.fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
+            dx = FP8LinearFunctionBase.fp8_mlp_bwd(
+                hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2, True
+            )

         self.x = None

-        residual_grad = do3
-
-        hidden_states_grad = dx
-
+        residual_grad = hidden_states_grad
         l_aux_grad = paddle.ones(1, dtype=self.l_aux.dtype) * self.alpha
+        final_hidden_states_grad = hidden_states_grad

-        final_hidden_states_grad = do3
-
-        return (hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad)
+        if self.send_mtp_embed:
+            return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad)
+        else:
+            return (dx, residual_grad, l_aux_grad, final_hidden_states_grad)


 class DecoderLayerNode(ScheduleNode):
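
Note: the backward changes above slice the gradient of the concatenated tensor back into its hidden_states and mtp parts using the shape saved in forward. A minimal standalone sketch of that split (shapes and variable names below are illustrative, not taken from the PR):

import paddle

batch, seq, hidden = 2, 4, 8
hidden_states = paddle.randn([batch, seq, hidden])
inputs_embeds_mtp = paddle.randn([batch, seq, hidden])

# forward: pack along the feature axis, remember the mtp shape for backward
packed = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
mtp_embed_shape = inputs_embeds_mtp.shape

# backward: split the incoming gradient the same way the activations were packed
do3 = paddle.randn(packed.shape)
hidden_size = do3.shape[-1] - mtp_embed_shape[-1]
hidden_states_grad = do3[..., :hidden_size]
inputs_embeds_mtp_grad = do3[..., hidden_size:]

assert hidden_states_grad.shape == hidden_states.shape
assert inputs_embeds_mtp_grad.shape == inputs_embeds_mtp.shape
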
@@ -749,6 +760,9 @@ def attn_backward(self, output_grad):
                 hs_grad,
                 token_probs_grad,
             ) = output_grad
+            inputs_embeds_mtp_grad_shape = hidden_states_grad.shape
+            inputs_embeds_mtp_grad_shape[-1] = -1
+            inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape)
         else:
             hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad

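
Note: the three added lines view the mtp gradient with the leading dims of hidden_states_grad and an inferred last dim. A small sketch mirroring that .view call (names and sizes are assumptions, and the -1 inference is assumed to behave as in the PR's Paddle version):

import paddle

hidden_states_grad = paddle.randn([2, 4, 8])
inputs_embeds_mtp_grad = paddle.randn([2 * 4 * 8])   # e.g. arrives flattened

target_shape = hidden_states_grad.shape   # [2, 4, 8]
target_shape[-1] = -1                     # let the last dim be inferred
inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(target_shape)
print(inputs_embeds_mtp_grad.shape)       # [2, 4, 8]
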
@@ -906,8 +920,11 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
             combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)

             final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
-            inputs = final_out + combine_fwd_out
-
+            if final_out.shape[-1] != combine_fwd_out.shape[-1]:
+                final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # broadcast and add directly
+            else:
+                final_out += combine_fwd_out
+            inputs = final_out
             combine_fwd_out._record_stream()

             paddle.base.core.nvprof_nvtx_pop()
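
Note: when send_mtp_embed is on, final_out still carries the extra mtp columns, so combine_fwd_out is added only onto the leading columns. A toy sketch of that width check (sizes below are illustrative):

import paddle

hidden = 8
final_out = paddle.zeros([2, 4, 2 * hidden])    # hidden_states columns followed by mtp columns
combine_fwd_out = paddle.ones([2, 4, hidden])

if final_out.shape[-1] != combine_fwd_out.shape[-1]:
    final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out
else:
    final_out += combine_fwd_out

print(final_out[0, 0])   # first `hidden` entries are 1, the trailing mtp entries stay 0
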
@@ -1072,7 +1089,7 @@ def forward(self, args):
         if self.config.send_mtp_embed:
             batch_size, _, hidden_size = hidden_states.shape
             batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-            inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+            inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
             hidden_states = hidden_states[..., :batch_size_mtp]

         has_gradient = not hidden_states.stop_gradient
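
Note: the packed feature dim holds (num_nextn_predict_layers + 1) chunks of width batch_size_mtp, so [..., batch_size_mtp:] keeps every chunk after the hidden_states part, while the old [..., -batch_size_mtp:] kept only the last chunk; the two coincide only when num_nextn_predict_layers == 1. An illustrative sketch (sizes are made up):

import paddle

num_nextn_predict_layers = 2
batch_size_mtp = 4
hidden_size = batch_size_mtp * (num_nextn_predict_layers + 1)   # 12
hidden_states = paddle.randn([2, 3, hidden_size])

old_slice = hidden_states[..., -batch_size_mtp:]   # shape [2, 3, 4]: only the last chunk
new_slice = hidden_states[..., batch_size_mtp:]    # shape [2, 3, 8]: all mtp chunks
print(old_slice.shape, new_slice.shape)
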
@@ -1129,7 +1146,7 @@ def attn_compute(self, args):

         batch_size, _, hidden_size = hidden_states.shape
         batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-        inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+        inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
         hidden_states = hidden_states[..., :batch_size_mtp]

         def attn_compute_func(hidden_states):
@@ -1162,7 +1179,7 @@ def attn_compute_for_fusion(self, args):
         # slice from holy tensor
         batch_size, _, hidden_size = hidden_states.shape
         batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-        inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+        inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
         hidden_states = hidden_states[..., :batch_size_mtp]

         hidden_states, residual = self.self_attn_compute(hidden_states)

paddlenlp/transformers/fp8_utils.py

Lines changed: 19 additions & 12 deletions
@@ -327,7 +327,7 @@ def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2):
         return o3

     @staticmethod
-    def fp8_mlp_bwd(do3, x, w1, w2):
+    def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):
         do3_orig_shape = do3.shape
         do3 = do3.reshape([-1, do3_orig_shape[-1]])

@@ -336,22 +336,29 @@ def fp8_mlp_bwd(do3, x, w1, w2):

         x_fp8, x_scale, x_t_fp8, x_t_scale = FP8LinearFunctionBase.padding_and_quant_input(x)

-        dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
-            do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=True
-        )
-
-        if len(x_orig_shape) > 2:
-            dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
+        if apply_backward_hook:
+            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+                do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=apply_backward_hook
+            )
+            if len(x_orig_shape) > 2:
+                dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
+            return dx
+        else:
+            dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+                do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=apply_backward_hook
+            )
+            if len(x_orig_shape) > 2:
+                dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])

-        return dx
+        return dx, dw1, dw2

     @staticmethod
     def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2):
         # ===== recompute norm_output =====
         norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps)

         # ===== compute fp8_mlp_fwd =====
-        d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2)
+        d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2, True)

         # ===== compute norm grad =====
         dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps)
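
Note: fp8_mlp_bwd now takes apply_backward_hook and changes its return arity with it. A toy, non-FP8 sketch of that calling convention (toy_mlp_bwd and the plain matmul backward below are placeholders, not PaddleNLP APIs):

import paddle


def toy_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):
    # stand-in for the shared backward: a two-layer linear MLP, no activation, no FP8
    h = paddle.matmul(x, w1)
    dh = paddle.matmul(do3, w2.T)
    dw2 = paddle.matmul(h.T, do3)
    dw1 = paddle.matmul(x.T, dh)
    dx = paddle.matmul(dh, w1.T)
    if apply_backward_hook:
        return dx          # weight grads assumed to be accumulated via hooks elsewhere
    return dx, dw1, dw2


x = paddle.randn([6, 8])
w1 = paddle.randn([8, 16])
w2 = paddle.randn([16, 8])
do3 = paddle.randn([6, 8])

dx_only = toy_mlp_bwd(do3, x, w1, w2, apply_backward_hook=True)
dx, dw1, dw2 = toy_mlp_bwd(do3, x, w1, w2)
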
@@ -480,7 +487,7 @@ def forward(ctx, x, norm_w, w1, w2, norm_eps):
         norm_output = norm_output.reshape([-1, x_orig_shape[-1]])

         # ===== call func fp8_mlp_fwd =====
-        o3, _, _ = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)
+        _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)

         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
@@ -517,7 +524,7 @@ def backward(ctx, do3):
         )

         # ===== call func common_fp8_mlp_bwd =====
-        d_norm_output, dw1, dw2 = FP8LinearFunctionBase.fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2)
+        d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2)

         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
@@ -574,7 +581,7 @@ def backward(ctx, do3):
         )

         # ===== call func common_fp8_mlp_bwd =====
-        dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2)
+        dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, False)

         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
