
Commit 6a3fb15

not recomput o1 (#10953)
* not recomput o1
* not recomput o1
* not recomput o1
* not recomput o1
* not recomput o1
* fix
1 parent e91f55a commit 6a3fb15

File tree: 2 files changed, +105 -40 lines

paddlenlp/transformers/deepseek_v2/modeling_pp.py
paddlenlp/transformers/fp8_utils.py
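What the commit does: `FP8LinearFunctionBase.fp8_mlp_fwd` now also returns the first GEMM output `o1`, the MLP autograd function saves it for backward, and `common_fp8_mlp_bwd` accepts an optional `o1` so the backward pass can skip the `deep_gemm` recompute whenever a cached value is available. Below is a minimal, framework-free sketch of that cache-instead-of-recompute pattern; all names are illustrative stand-ins (a plain ReLU MLP in NumPy), not PaddleNLP APIs.

import numpy as np

class CachedMLP:
    """Toy stand-in for an MLP that caches its first matmul output `o1`
    in forward so backward can reuse it instead of recomputing the GEMM."""

    def __init__(self, w1, w2):
        self.w1, self.w2 = w1, w2

    def forward(self, x):
        o1 = x @ self.w1              # first GEMM; the tensor this commit stops recomputing
        o2 = np.maximum(o1, 0.0)      # ReLU stand-in for the real swiglu activation
        o3 = o2 @ self.w2
        self.saved = (x, o1)          # cache o1 alongside the input (costs activation memory)
        return o3

    def backward(self, do3):
        x, o1 = self.saved            # reuse the cached o1: no extra GEMM here
        o2 = np.maximum(o1, 0.0)
        do2 = do3 @ self.w2.T
        do1 = do2 * (o1 > 0.0)
        dx = do1 @ self.w1.T
        dw1 = x.T @ do1
        dw2 = o2.T @ do3
        return dx, dw1, dw2

mlp = CachedMLP(np.random.randn(16, 64), np.random.randn(64, 16))
out = mlp.forward(np.random.randn(8, 16))
dx, dw1, dw2 = mlp.backward(np.ones_like(out))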

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 49 additions & 13 deletions
@@ -49,8 +49,8 @@
     DeepseekV2PretrainedModel,
     DeepseekV2PretrainingCriterion,
     DeepseekV2RMSNorm,
-    set_global_step,
     TemporaryVarContext,
+    set_global_step,
 )
 
 try:
@@ -187,13 +187,13 @@ def forward_without_residual(self, inputs):
         with paddle.no_grad():
             if self.shared_experts is not None:
                 if self.using_post_norm_recompute:
-                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                    _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         norm_out, self.shared_experts.w1, self.shared_experts.w2
                     )
                     norm_out = None
                     del norm_out
                 else:
-                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                    _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         hidden_states, self.shared_experts.w1, self.shared_experts.w2
                     )
                 residual = residual + shared_expert_output
@@ -229,13 +229,13 @@ def forward(self, inputs):
         with paddle.no_grad():
             if self.shared_experts is not None:
                 if self.using_post_norm_recompute:
-                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                    _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         norm_out, self.shared_experts.w1, self.shared_experts.w2
                     )
                     norm_out = None
                     del norm_out
                 else:
-                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                    _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         hidden_states, self.shared_experts.w1, self.shared_experts.w2
                     )
                 final_hidden_states = final_hidden_states + shared_expert_output
@@ -282,10 +282,18 @@ def backward(self, output_grad):
         residual_grad = hidden_states_grad
         l_aux_grad = paddle.ones(1, dtype=self.l_aux.dtype) * self.alpha
         final_hidden_states_grad = hidden_states_grad
-
+
         if self.using_post_norm_recompute:
             if self.send_mtp_embed:
-                return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar)
+                return (
+                    inputs_embeds_mtp_grad,
+                    dx,
+                    residual_grad,
+                    l_aux_grad,
+                    final_hidden_states_grad,
+                    norm_out,
+                    invar,
+                )
             else:
                 return (dx, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar)
         else:
@@ -724,7 +732,6 @@ def post_process_forward(self, inputs, with_residual=True):
         inputs = (inputs_embeds_mtp, *inputs) if self.send_mtp_embed else inputs
         inputs = (*inputs, norm_out) if self.using_post_norm_recompute else inputs
 
-
         if with_residual:
             inputs = self.post_process_node.forward(inputs)
         else:
@@ -736,7 +743,15 @@ def post_process_backward(self, output_grad, event_to_wait=None):
 
         if self.using_post_norm_recompute:
             if self.send_mtp_embed:
-                inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar = grad
+                (
+                    inputs_embeds_mtp_grad,
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    final_hidden_states_grad,
+                    norm_out,
+                    invar,
+                ) = grad
             else:
                 hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar = grad
         else:
@@ -815,17 +830,30 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
     def mlp_backward(self, output_grad):
         if self.using_post_norm_recompute:
             if self.send_mtp_embed:
-                inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad, norm_out, invar = output_grad
+                (
+                    inputs_embeds_mtp_grad,
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    hidden_states_out_grad,
+                    norm_out,
+                    invar,
+                ) = output_grad
             else:
                 hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad, norm_out, invar = output_grad
         else:
             if self.send_mtp_embed:
-                inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
+                (
+                    inputs_embeds_mtp_grad,
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    hidden_states_out_grad,
+                ) = output_grad
             else:
                 hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
         hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(hidden_states_out_grad)
 
-
         ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
         ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
@@ -845,7 +873,15 @@ def dispatch_backward(self, output_grad, async_finish=False, previous_event=None
                     invar,
                 ) = output_grad
             else:
-                hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad, norm_out, invar = output_grad
+                (
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    hs_dispatched_grad,
+                    dispatched_probs_grad,
+                    norm_out,
+                    invar,
+                ) = output_grad
         else:
             if self.send_mtp_embed:
                 (

paddlenlp/transformers/fp8_utils.py

Lines changed: 56 additions & 27 deletions
@@ -1,16 +1,17 @@
 # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-#
+
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-#
+
 # http://www.apache.org/licenses/LICENSE-2.0
-#
+
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os
 from functools import partial
 
@@ -301,16 +302,22 @@ def compute_expert_w_grad(
         return result
 
     @staticmethod
-    def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=False):
+    def common_fp8_mlp_bwd(
+        do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=None, x_scale=None, apply_backward_hook=False
+    ):
+        if o1 is not None and (x_fp8 is not None or x_scale is not None):
+            raise ValueError("When o1 is provided, both x_fp8 and x_scale must be None.")
 
-        # # ===== [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) =====
-        # o1, x_t_fp8, x_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
-        #     x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_transpose_quant"
-        # )
+        if o1 is None:
+            if x_fp8 is None or x_scale is None:
+                raise ValueError("When o1 is None, both x_fp8 and x_scale must be provided.")
 
-        w1_fp8, w1_scale = weight_quant(w1, True)
-        o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=118)
+            # # ===== [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) =====
+
+            # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8)
+            w1_fp8, w1_scale = weight_quant(w1, True)
+            o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
+            deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=118)
 
         # ===== [recompute] o2 = swiglu(o1) =====
         o2 = swiglu(o1)
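With the new keyword layout, `common_fp8_mlp_bwd` can be driven in two mutually exclusive ways: pass the cached `o1`, or pass `x_fp8`/`x_scale` so `o1` is recomputed via `deep_gemm`. A tiny self-contained sketch of that argument contract follows; the `resolve_o1` name and the placeholder return value are hypothetical, purely for illustration.

def resolve_o1(o1=None, x_fp8=None, x_scale=None):
    # Mirrors the guard added to common_fp8_mlp_bwd in the hunk above.
    if o1 is not None and (x_fp8 is not None or x_scale is not None):
        raise ValueError("When o1 is provided, both x_fp8 and x_scale must be None.")
    if o1 is None:
        if x_fp8 is None or x_scale is None:
            raise ValueError("When o1 is None, both x_fp8 and x_scale must be provided.")
        o1 = ("recomputed from", x_fp8, x_scale)  # placeholder for the deep_gemm recompute path
    return o1

print(resolve_o1(o1="cached_o1"))                    # fast path: reuse the saved activation
print(resolve_o1(x_fp8="x_fp8", x_scale="x_scale"))  # fallback: recompute o1 from the quantized input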
@@ -409,7 +416,15 @@ def fp8_mlp_fwd(x, w1, w2):
         if len(x_orig_shape) > 2:
             o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
 
-        return x_fp8, x_scale, o3
+        return o1, x_fp8, x_scale, o3
+
+    @staticmethod
+    def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2):
+        # ===== compute norm_output =====
+        norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps)
+        # ===== compute fp8_mlp_fwd =====
+        _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)
+        return o3
 
     @staticmethod
     def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):
@@ -423,14 +438,30 @@ def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):
 
         if apply_backward_hook:
             dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
-                do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=apply_backward_hook
+                do3,
+                x_t_fp8,
+                x_t_scale,
+                w1,
+                w2,
+                o1=None,
+                x_fp8=x_fp8,
+                x_scale=x_scale,
+                apply_backward_hook=apply_backward_hook,
             )
             if len(x_orig_shape) > 2:
                 dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
             return dx
         else:
             dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(
-                do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=apply_backward_hook
+                do3,
+                x_t_fp8,
+                x_t_scale,
+                w1,
+                w2,
+                o1=None,
+                x_fp8=x_fp8,
+                x_scale=x_scale,
+                apply_backward_hook=apply_backward_hook,
             )
             if len(x_orig_shape) > 2:
                 dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
@@ -580,14 +611,16 @@ def forward(ctx, x, norm_w, w1, w2, norm_eps):
             norm_output = norm_output.reshape([-1, x_orig_shape[-1]])
 
         # ===== call func fp8_mlp_fwd =====
-        _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)
+        _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)
 
         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
             o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
 
         # ===== save for backward =====
         ctx.save_for_backward(
+            norm_output,
+            invar,
             x,
             norm_w,
             w1,
@@ -604,21 +637,15 @@ def backward(ctx, do3):
             do3 = do3.reshape([-1, do3_orig_shape[-1]])
 
         # ===== recive saved tensors =====
-        x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor()
-
-        # ===== recompute norm =====
-        norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps)
-
-        # ===== compute x_t_fp8, x_t_scale for dw1 =====
-        norm_output = norm_output.reshape([-1, x_orig_shape[-1]])
+        norm_output, invar, x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor()
 
         x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             norm_output, output_scale_transpose=True, quant_method="1x128", input_transpose=True
         )
 
         # ===== call func common_fp8_mlp_bwd =====
         d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(
-            do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2
+            do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale
         )
 
         # ===== reshape to origin shape =====
@@ -639,13 +666,14 @@ def forward(ctx, x, w1, w2):
             x = x.reshape([-1, x_orig_shape[-1]])
 
         # ===== call func fp8_mlp_fwd =====
-        x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2)
+        o1, x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2)
         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
             o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
 
         # ===== save for backward =====
         ctx.save_for_backward(
+            o1,
             x_fp8,
             x_scale,
             w1,
@@ -661,7 +689,7 @@ def backward(ctx, do3):
             do3 = do3.reshape([-1, do3_orig_shape[-1]])
 
         # ===== recive saved tensors =====
-        x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor()
+        o1, x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor()
 
         # ===== compute x_t_fp8, x_t_scale for dw1 =====
         x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous())
@@ -676,8 +704,9 @@ def backward(ctx, do3):
         )
 
         # ===== call func common_fp8_mlp_bwd =====
-        dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, True)
-
+        dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+            do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
+        )
         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
             dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
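Keeping `o1` alive via `ctx.save_for_backward` trades activation memory for the `deep_gemm` call the backward no longer issues. A rough back-of-envelope sketch is below; the shapes are illustrative assumptions, not values taken from this PR, and real sizes depend on the model config and micro-batch.

# Illustrative shapes only: [rows of x], [cols of x], [rows of w1].
tokens, hidden, ffn_out = 4096, 7168, 4096
o1_bytes = tokens * ffn_out * 2              # bf16 o1 kept alive until backward
gemm_flops = 2 * tokens * hidden * ffn_out   # the o1 recompute GEMM that is skipped
print(f"extra activation memory per layer ~ {o1_bytes / 2**20:.0f} MiB")
print(f"FLOPs avoided per layer per backward ~ {gemm_flops / 1e12:.2f} TFLOPs")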
