@@ -1,16 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-#
+
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
-#
+
# http://www.apache.org/licenses/LICENSE-2.0
-#
+
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
from functools import partial

@@ -301,16 +302,22 @@ def compute_expert_w_grad(
        return result

    @staticmethod
-    def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=False):
+    def common_fp8_mlp_bwd(
+        do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=None, x_scale=None, apply_backward_hook=False
+    ):
+        if o1 is not None and (x_fp8 is not None or x_scale is not None):
+            raise ValueError("When o1 is provided, both x_fp8 and x_scale must be None.")

-        # # ===== [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) =====
-        # o1, x_t_fp8, x_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
-        #     x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_transpose_quant"
-        # )
+        if o1 is None:
+            if x_fp8 is None or x_scale is None:
+                raise ValueError("When o1 is None, both x_fp8 and x_scale must be provided.")

-        w1_fp8, w1_scale = weight_quant(w1, True)
-        o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=118)
+            # # ===== [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) =====
+
+            # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8)
+            w1_fp8, w1_scale = weight_quant(w1, True)
+            o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
+            deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=118)

        # ===== [recompute] o2 = swiglu(o1) =====
        o2 = swiglu(o1)
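
As a reading aid, a minimal sketch of the two calling modes implied by the new signature (tensor names are taken from the surrounding hunks and are assumed to already be quantized/transposed as in fp8_mlp_bwd; this snippet is illustrative, not part of the PR):

# Mode 1: the forward pass saved o1, so the backward skips the recompute GEMM.
dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
    do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
)

# Mode 2: o1 was not saved; pass the FP8 input and its scale so o1 is recomputed internally.
dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(
    do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale
)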
@@ -409,7 +416,15 @@ def fp8_mlp_fwd(x, w1, w2):
        if len(x_orig_shape) > 2:
            o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])

-        return x_fp8, x_scale, o3
+        return o1, x_fp8, x_scale, o3
+
+    @staticmethod
+    def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2):
+        # ===== compute norm_output =====
+        norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps)
+        # ===== compute fp8_mlp_fwd =====
+        _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)
+        return o3

    @staticmethod
    def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):
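
A brief usage sketch for the new helper, assuming norm_w is the RMSNorm weight and norm_eps its epsilon from the caller's layer config (names are illustrative, not from the PR):

# Fused RMSNorm + FP8 MLP forward; only the final activation o3 is returned,
# so the caller does not hold on to the intermediate FP8 buffers.
o3 = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2)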
@@ -423,14 +438,30 @@ def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):

        if apply_backward_hook:
            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
-                do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=apply_backward_hook
+                do3,
+                x_t_fp8,
+                x_t_scale,
+                w1,
+                w2,
+                o1=None,
+                x_fp8=x_fp8,
+                x_scale=x_scale,
+                apply_backward_hook=apply_backward_hook,
            )
            if len(x_orig_shape) > 2:
                dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
            return dx
        else:
            dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(
-                do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=apply_backward_hook
+                do3,
+                x_t_fp8,
+                x_t_scale,
+                w1,
+                w2,
+                o1=None,
+                x_fp8=x_fp8,
+                x_scale=x_scale,
+                apply_backward_hook=apply_backward_hook,
            )
            if len(x_orig_shape) > 2:
                dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
@@ -580,14 +611,16 @@ def forward(ctx, x, norm_w, w1, w2, norm_eps):
            norm_output = norm_output.reshape([-1, x_orig_shape[-1]])

        # ===== call func fp8_mlp_fwd =====
-        _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)
+        _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)

        # ===== reshape to origin shape =====
        if len(x_orig_shape) > 2:
            o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])

        # ===== save for backward =====
        ctx.save_for_backward(
+            norm_output,
+            invar,
            x,
            norm_w,
            w1,
@@ -604,21 +637,15 @@ def backward(ctx, do3):
        do3 = do3.reshape([-1, do3_orig_shape[-1]])

        # ===== recive saved tensors =====
-        x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor()
-
-        # ===== recompute norm =====
-        norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps)
-
-        # ===== compute x_t_fp8, x_t_scale for dw1 =====
-        norm_output = norm_output.reshape([-1, x_orig_shape[-1]])
+        norm_output, invar, x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor()

        x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
            norm_output, output_scale_transpose=True, quant_method="1x128", input_transpose=True
        )

        # ===== call func common_fp8_mlp_bwd =====
        d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(
-            do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2
+            do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale
        )

        # ===== reshape to origin shape =====
@@ -639,13 +666,14 @@ def forward(ctx, x, w1, w2):
            x = x.reshape([-1, x_orig_shape[-1]])

        # ===== call func fp8_mlp_fwd =====
-        x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2)
+        o1, x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2)
        # ===== reshape to origin shape =====
        if len(x_orig_shape) > 2:
            o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])

        # ===== save for backward =====
        ctx.save_for_backward(
+            o1,
            x_fp8,
            x_scale,
            w1,
@@ -661,7 +689,7 @@ def backward(ctx, do3):
        do3 = do3.reshape([-1, do3_orig_shape[-1]])

        # ===== recive saved tensors =====
-        x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor()
+        o1, x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor()

        # ===== compute x_t_fp8, x_t_scale for dw1 =====
        x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous())
@@ -676,8 +704,9 @@ def backward(ctx, do3):
        )

        # ===== call func common_fp8_mlp_bwd =====
-        dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, True)
-
+        dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+            do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
+        )
        # ===== reshape to origin shape =====
        if len(x_orig_shape) > 2:
            dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
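
Finally, a hypothetical sanity-check sketch of the argument validation that the refactored common_fp8_mlp_bwd enforces (tensors are assumed to be set up as in fp8_mlp_bwd above; this snippet is illustrative and not part of the PR):

# o1 together with x_fp8/x_scale is rejected:
#   ValueError: When o1 is provided, both x_fp8 and x_scale must be None.
# o1=None without x_fp8/x_scale is rejected as well:
#   ValueError: When o1 is None, both x_fp8 and x_scale must be provided.
for bad_kwargs in ({"o1": o1, "x_fp8": x_fp8, "x_scale": x_scale}, {"o1": None}):
    try:
        FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_t_fp8, x_t_scale, w1, w2, **bad_kwargs)
    except ValueError as err:
        print(err)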