@@ -217,22 +217,23 @@ def compute_fp8_linear(
        input, weight, weight_transpose=False, return_transpose_only=False, return_mode="output_only", *, out=None
    ):
        """
-       FP8 Linear 计算函数,支持多种返回模式,支持量化/未量化输入。
+       FP8 Linear computation function supporting multiple return modes and quantized/unquantized inputs.

        Args:
-           input: 输入张量(原始或已经量化的(input_fp8, input_scale) 元组)。
-           weight: 权重张量。
-           weight_transpose (bool): 是否转置权重。
-           return_transpose_only (bool): 是否仅返回转置后的权重。
-           return_mode (str): 返回模式,可选:
-               - "output_only": 仅返回输出张量。
-               - "with_input_quant": 返回输出 + 输入量化结果 (input_fp8, input_scale)。
-               - "with_input_transpose_quant": 返回输出(out) + 输入量化转置结果 (input_t_fp8, input_t_scale).
+           input: Input tensor, either a raw tensor or an already quantized (input_fp8, input_scale) tuple.
+           weight: Weight tensor.
+           weight_transpose (bool): Whether to transpose the weight.
+           return_transpose_only (bool): Whether to return only the transposed weight.
+           return_mode (str): Return mode, one of:
+               - "output_only": returns only the output tensor.
+               - "with_input_quant": returns the output plus the input quantization results (input_fp8, input_scale).
+               - "with_input_transpose_quant": returns the output plus the transposed input quantization results (input_t_fp8, input_t_scale).
+
        Returns:
-           根据 return_mode 返回不同组合的张量。
+           Different combinations of tensors depending on return_mode.

        Raises:
-           RuntimeError: 如果 return_mode 不支持。
+           RuntimeError: If return_mode is not supported.
        """
        # check input
        is_input_quantized = isinstance(input, (tuple, list)) and len(input) == 2
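For reference, the three return modes are exercised later in this file roughly as follows. This is a usage sketch only: the shapes, dtypes, and the [in_features, out_features] weight layout are illustrative assumptions, and FP8LinearFunctionBase is assumed to be importable from this module.

import paddle

# Illustrative inputs; shapes and weight layout are assumptions, dims are multiples of 128 for blockwise quant.
x = paddle.randn([1024, 2048]).astype("bfloat16")
w = paddle.randn([2048, 4096]).astype("bfloat16")

# "output_only" (default): just the GEMM output.
out = FP8LinearFunctionBase.compute_fp8_linear(x, w, weight_transpose=True, return_transpose_only=True)

# "with_input_quant": also returns the quantized input so it can be reused (e.g. saved for backward).
out, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear(
    x, w, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant"
)

# "with_input_transpose_quant": also returns the transposed quantized input, later used for weight gradients.
out, x_t_fp8, x_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
    x, w, return_mode="with_input_transpose_quant"
)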
@@ -294,7 +295,7 @@ def compute_expert_w_grad(
        rtn_dtype=paddle.bfloat16,
    ):
        """
-       统一处理 expert_w 的梯度计算(支持 main_grad 和普通 grad)
+       Unified gradient computation for expert_w weights (supports both main_grad and regular grad).
        """

        if input_t is None or numpy.prod(input_t.shape) == 0:
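For context on the main_grad / regular-grad distinction mentioned in the docstring: the common pattern is to accumulate the computed weight gradient into an fp32 master buffer (weight.main_grad) when one exists, and into weight.grad otherwise. The sketch below is illustrative only, not the actual body of compute_expert_w_grad, and accumulate_weight_grad is a hypothetical helper name.

import paddle

def accumulate_weight_grad(weight, dw):
    # Hypothetical helper: route a freshly computed gradient chunk into main_grad or grad.
    if hasattr(weight, "main_grad") and weight.main_grad is not None:
        # fp32 master-gradient buffer kept alongside the (bf16) parameter.
        weight.main_grad.add_(dw.astype(weight.main_grad.dtype))
    elif weight.grad is not None:
        weight.grad.add_(dw.astype(weight.grad.dtype))
    else:
        # No gradient buffer allocated yet; a real implementation would create or attach one here.
        pass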
@@ -352,22 +353,22 @@ def common_fp8_mlp_bwd(
            if x_fp8 is None or x_scale is None:
                raise ValueError("When o1 is None, both x_fp8 and x_scale must be provided.")

-           # # ===== [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) =====
+           # [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8)

            # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8)
            w1_fp8, w1_scale = weight_quant(w1, True)
            o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
            deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=get_sm_num())

-       # ===== [recompute] o2 = swiglu(o1) =====
+       # [recompute] o2 = swiglu(o1)
        o2 = swiglu(o1)

-       # ===== do2 = deep_gemm(do3_fp8, w2_fp8)
+       # do2 = deep_gemm(do3_fp8, w2_fp8)
        do2, do3_t_fp8, do3_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
            do3, w2, return_mode="with_input_transpose_quant"
        )

-       # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
+       # dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
        o2 = FP8LinearFunctionBase.padding(o2, 0)
        o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
            o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
@@ -397,15 +398,15 @@ def common_fp8_mlp_bwd(
            o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32
        )

-       # ===== do1 = swiglu_grad(o1, None, do2) =====
+       # do1 = swiglu_grad(o1, None, do2)
        do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2)

-       # ===== dx = deep_gemm(do1_fp8, w1_fp8) =====
+       # dx = deep_gemm(do1_fp8, w1_fp8)
        dx, do1_t_fp8, do1_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
            do1, w1, return_mode="with_input_transpose_quant"
        )

-       # ===== dw1 = deep_gemm(x_t_fp8, do1_t_fp8) =====
+       # dw1 = deep_gemm(x_t_fp8, do1_t_fp8)
        if apply_backward_hook:
            if WeightGradStore.enabled:
                WeightGradStore.put(
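The chain above (recompute o1 and o2, then dw2, do1, dx, dw1) matches the standard gradients of an o3 = swiglu(x @ w1) @ w2 MLP. The unquantized reference below is a sketch for checking the math only: it assumes an [in_features, out_features] weight layout and the usual SwiGLU half-split, and it ignores FP8 quantization, padding, and the WeightGradStore / backward-hook paths.

import paddle
import paddle.nn.functional as F

def mlp_bwd_reference(do3, x, w1, w2):
    # Rebuild the forward graph in bf16 and let autograd produce dx, dw1, dw2.
    x, w1, w2 = (t.detach() for t in (x, w1, w2))
    for t in (x, w1, w2):
        t.stop_gradient = False
    o1 = paddle.matmul(x, w1)               # [recompute] o1 = x @ w1
    gate, up = paddle.chunk(o1, 2, axis=-1)
    o2 = F.silu(gate) * up                  # [recompute] o2 = swiglu(o1)
    o3 = paddle.matmul(o2, w2)
    dx, dw1, dw2 = paddle.grad([o3], [x, w1, w2], grad_outputs=[do3])
    return dx, dw1, dw2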
@@ -442,15 +443,15 @@ def fp8_mlp_fwd(x, w1, w2):
        x_orig_shape = x.shape
        x = x.reshape([-1, x_orig_shape[-1]])

-       # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) =====
+       # o1 = deep_gemm(x_fp8, w1_t_fp8)
        o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear(
            x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant"
        )

-       # ===== o2 = swiglu(o1) =====
+       # o2 = swiglu(o1)
        o2 = swiglu(o1)

-       # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
+       # o3 = deep_gemm(o2_fp8, w2_t_fp8)
        o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True)

        if len(x_orig_shape) > 2:
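In unquantized form the forward above reduces to two matmuls around a SwiGLU activation. A bf16 sketch under the same weight-layout and SwiGLU assumptions as the backward reference earlier (the FP8 path additionally returns x_fp8/x_scale so backward can reuse the quantized input):

import paddle
import paddle.nn.functional as F

def mlp_fwd_reference(x, w1, w2):
    o1 = paddle.matmul(x, w1)               # o1 = x @ w1  (gate/up projection)
    gate, up = paddle.chunk(o1, 2, axis=-1)
    o2 = F.silu(gate) * up                  # o2 = swiglu(o1)
    o3 = paddle.matmul(o2, w2)              # o3 = o2 @ w2 (down projection)
    return o1, o2, o3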
@@ -460,9 +461,9 @@ def fp8_mlp_fwd(x, w1, w2):

    @staticmethod
    def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2):
-       # ===== compute norm_output =====
+       # compute norm_output
        norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps)
-       # ===== compute fp8_mlp_fwd =====
+       # compute fp8_mlp_fwd
        _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)
        return o3

@@ -510,10 +511,10 @@ def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):

    @staticmethod
    def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2):
-       # ===== recompute norm_output =====
+       # recompute norm_output
        norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps)

-       # ===== compute fp8_mlp_fwd =====
+       # compute fp8_mlp_bwd
        d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2, True)

        if hasattr(norm_w, "_apply_backward_hook"):
@@ -567,7 +568,7 @@ def backward(ctx, dout):
                x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
            )

-           # ===== dx = deep_gemm(dout_fp8, w_fp8)
+           # dx = deep_gemm(dout_fp8, w_fp8)
            dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
                dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant"
            )
@@ -576,15 +577,15 @@ def backward(ctx, dout):
        else:
            x_t_fp8, x_t_scale = x

-           # ===== dx = deep_gemm(dout_fp8, w_fp8)
+           # dx = deep_gemm(dout_fp8, w_fp8)
            dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
                dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant"
            )
            dx_orig_shape = dout.shape[:-1]
            dx_orig_shape.append(ctx.x_t_shape[0])
            dx = dx.reshape(dx_orig_shape)

-       # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
+       # dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
        FP8LinearFunctionBase.compute_expert_w_grad(
            x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight, paddle.float32
        )
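The two GEMMs in this backward are the standard linear-layer gradients. A reference sketch, assuming the same [in_features, out_features] weight layout as the earlier sketches so the forward is y = x @ w (the FP8 code feeds transposed quantized copies to compute_expert_w_grad instead):

import paddle

def linear_bwd_reference(dout_2d, x_2d, w):
    # dx = dout @ w.T  (gradient w.r.t. the input)
    dx = paddle.matmul(dout_2d, w, transpose_y=True)
    # dw = x.T @ dout  (gradient w.r.t. the weight, the role of compute_expert_w_grad above)
    dw = paddle.matmul(x_2d, dout_2d, transpose_x=True)
    return dx, dw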
@@ -668,20 +669,20 @@ def forward(self, x):
class FusedNormFP8MLPFunction(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, norm_w, w1, w2, norm_eps):
-       # ===== compute norm_output =====
+       # compute norm_output
        norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps)
-       # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
+       # reshape for deep_gemm, since deep_gemm only supports 2D
        x_orig_shape = norm_output.shape
        norm_output = norm_output.reshape([-1, x_orig_shape[-1]])

-       # ===== call func fp8_mlp_fwd =====
+       # call fp8_mlp_fwd
        _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)

-       # ===== reshape to origin shape =====
+       # reshape to original shape
        if len(x_orig_shape) > 2:
            o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])

-       # ===== save for backward =====
+       # save for backward
        ctx.save_for_backward(
            norm_output,
            invar,
@@ -696,27 +697,27 @@ def forward(ctx, x, norm_w, w1, w2, norm_eps):

    @staticmethod
    def backward(ctx, do3):
-       # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
+       # reshape for deep_gemm, since deep_gemm only supports 2D
        do3_orig_shape = do3.shape
        do3 = do3.reshape([-1, do3_orig_shape[-1]])

-       # ===== recive saved tensors =====
+       # receive saved tensors
        norm_output, invar, x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor()

        x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
            norm_output, output_scale_transpose=True, quant_method="1x128", input_transpose=True
        )

-       # ===== call func common_fp8_mlp_bwd =====
+       # call common_fp8_mlp_bwd
        d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(
            do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale
        )

-       # ===== reshape to origin shape =====
+       # reshape to original shape
        if len(x_orig_shape) > 2:
            d_norm_output = d_norm_output.reshape([x_orig_shape[0], -1, d_norm_output.shape[-1]])

-       # ===== compute norm grad =====
+       # compute norm grad
        dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps)

        return dx, d_rms_norm_weight, dw1, dw2
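Both PyLayers here repeat the same flatten / restore bookkeeping because deep_gemm only accepts 2D operands. A minimal sketch of that pattern (run_2d_only is an illustrative name; fn stands for any 2D-only kernel):

import paddle

def run_2d_only(x, fn):
    # Flatten leading dims to [tokens, hidden], run the 2D-only kernel, then restore the original shape.
    orig_shape = x.shape
    out = fn(x.reshape([-1, orig_shape[-1]]))
    if len(orig_shape) > 2:
        out = out.reshape([orig_shape[0], -1, out.shape[-1]])
    return out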
@@ -725,17 +726,17 @@ def backward(ctx, do3):
class FP8MlpFunction(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, w1, w2, recompute_fwd_gate_up):
-       # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
+       # reshape for deep_gemm, since deep_gemm only supports 2D
        x_orig_shape = x.shape
        x = x.reshape([-1, x_orig_shape[-1]])

-       # ===== call func fp8_mlp_fwd =====
+       # call fp8_mlp_fwd
        o1, x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2)
-       # ===== reshape to origin shape =====
+       # reshape to original shape
        if len(x_orig_shape) > 2:
            o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])

-       # ===== save for backward =====
+       # save for backward
        o1 = None if recompute_fwd_gate_up else o1
        ctx.save_for_backward(
            o1,
@@ -749,14 +750,14 @@ def forward(ctx, x, w1, w2, recompute_fwd_gate_up):

    @staticmethod
    def backward(ctx, do3):
-       # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
+       # reshape for deep_gemm, since deep_gemm only supports 2D
        do3_orig_shape = do3.shape
        do3 = do3.reshape([-1, do3_orig_shape[-1]])

-       # ===== recive saved tensors =====
+       # receive saved tensors
        o1, x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor()

-       # ===== compute x_t_fp8, x_t_scale for dw1 =====
+       # compute x_t_fp8, x_t_scale for dw1
        x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous())
        x_dequant_fp16 = FP8LinearFunctionBase.padding(x_dequant_fp16, 0)

@@ -768,7 +769,7 @@ def backward(ctx, do3):
            return_transpose_only=True,
        )

-       # ===== call func common_fp8_mlp_bwd =====
+       # call common_fp8_mlp_bwd
        if o1 is None:
            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
                do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale, apply_backward_hook=True
@@ -777,7 +778,7 @@ def backward(ctx, do3):
            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
                do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
            )
-       # ===== reshape to origin shape =====
+       # reshape to original shape
        if len(x_orig_shape) > 2:
            dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
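Usage sketch for FP8MlpFunction: as a PyLayer it is invoked through apply(). Shapes, dtypes, and the [in_features, out_features] weight layout below are illustrative assumptions, and an FP8-capable device with the deep_gemm kernels is assumed to be available. In this code path common_fp8_mlp_bwd runs with apply_backward_hook=True, so weight gradients are handled through the hook / main_grad machinery rather than returned from the PyLayer.

import paddle

# Illustrative shapes only (multiples of 128 for the 1x128 blockwise quantization).
x = paddle.randn([2, 1024, 2048]).astype("bfloat16")
w1 = paddle.randn([2048, 8192]).astype("bfloat16")   # gate/up projection, assumed [in, out] layout
w2 = paddle.randn([4096, 2048]).astype("bfloat16")   # down projection (takes the SwiGLU half: 8192 // 2)

# PyLayers are invoked through apply(); recompute_fwd_gate_up=True drops o1 in forward
# and rebuilds it in backward from the saved quantized input (less memory, one extra GEMM).
o3 = FP8MlpFunction.apply(x, w1, w2, True)

In a real training step the weights would be trainable parameters so that the backward hooks have a gradient buffer to accumulate into.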