@@ -679,10 +679,9 @@ def fwd_gate_up(self, x_bf16, expert_w1, num_expert, tokens_per_expert):
679
679
else :
680
680
# quant x_bf16
681
681
x_fp8 , x_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
682
- x_bf16 , output_scale_transpose = False , quant_method = "1x128" , input_transpose = False
682
+ x_bf16 , output_scale_transpose = True , quant_method = "1x128" , input_transpose = False
683
683
)
684
-
685
- x_scale = paddle .transpose (paddle .transpose (x_scale , [1 , 0 ]).contiguous (), [1 , 0 ])
684
+ x_scale = x_scale .T
686
685
687
686
# compute gemm
688
687
o1 = paddle .empty ([x_fp8 .shape [0 ], w1_t_quant .shape [1 ]], dtype = expert_w1 [0 ].dtype )
@@ -730,9 +729,8 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
730
729
o3_shape = [o2_fp8 .shape [0 ], w2_quant .shape [1 ]]
731
730
if o3 is not None :
732
731
assert o3 .shape == o3_shape , "{} vs {}" .format (o3 .shape , o3_shape )
733
- o3 .zero_ ()
734
732
else :
735
- o3 = paddle .zeros (o3_shape , dtype = o1 .dtype )
733
+ o3 = paddle .empty (o3_shape , dtype = o1 .dtype )
736
734
if numpy .prod (o2_fp8 .shape ) != 0 :
737
735
if self .is_split_group_gemm :
738
736
split_group_gemm (o2_fp8 , o2_scale , w2_quant , w2_sacle , self .tokens_per_expert , o3 )
@@ -756,18 +754,17 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, inplace_swiglu_prob=False
756
754
757
755
# compute gemm
758
756
unzipped_grad_fp8 , unzipped_grad_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
759
- unzipped_grad , output_scale_transpose = False , quant_method = "1x128" , input_transpose = False
757
+ unzipped_grad , output_scale_transpose = True , quant_method = "1x128" , input_transpose = False
760
758
)
759
+ unzipped_grad_scale = unzipped_grad_scale .T
760
+
761
761
do2_s = paddle .empty ([unzipped_grad_fp8 .shape [0 ], bw_w2_quant .shape [1 ]], dtype = unzipped_grad .dtype )
762
762
if numpy .prod (unzipped_grad_fp8 .shape ) != 0 :
763
763
if self .is_split_group_gemm :
764
764
split_group_gemm (
765
765
unzipped_grad_fp8 , unzipped_grad_scale , bw_w2_quant , bw_w2_scale , self .tokens_per_expert , do2_s
766
766
)
767
767
else :
768
- unzipped_grad_scale = paddle .transpose (
769
- paddle .transpose (unzipped_grad_scale , [1 , 0 ]).contiguous (), [1 , 0 ]
770
- )
771
768
deep_gemm .m_grouped_gemm_fp8_fp8_bf16_nt_contiguous (
772
769
(unzipped_grad_fp8 , unzipped_grad_scale ),
773
770
(bw_w2_quant , bw_w2_scale ),
@@ -801,21 +798,19 @@ def bwd_gate_up_input(self, do1, expert_w1, dx=None):
801
798
802
799
# quant do1
803
800
do1_fp8 , do1_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
804
- do1 , output_scale_transpose = False , quant_method = "1x128" , input_transpose = False
801
+ do1 , output_scale_transpose = True , quant_method = "1x128" , input_transpose = False
805
802
)
806
-
803
+ do1_scale = do1_scale . T
807
804
# compute gemm
808
805
dx_shape = [do1_fp8 .shape [0 ], bw_w1_quant .shape [1 ]]
809
806
if dx is None :
810
807
dx = paddle .empty (shape = dx_shape , dtype = do1 .dtype )
811
808
else :
812
809
assert dx .shape == dx_shape , f"{ dx .shape } vs { dx_shape } "
813
- dx .zero_ ()
814
810
if numpy .prod (do1_fp8 .shape ) != 0 :
815
811
if self .is_split_group_gemm :
816
812
split_group_gemm (do1_fp8 , do1_scale , bw_w1_quant , bw_w1_scale , self .tokens_per_expert , dx )
817
813
else :
818
- do1_scale = paddle .transpose (paddle .transpose (do1_scale , [1 , 0 ]).contiguous (), [1 , 0 ])
819
814
deep_gemm .m_grouped_gemm_fp8_fp8_bf16_nt_contiguous (
820
815
(do1_fp8 , do1_scale ), (bw_w1_quant , bw_w1_scale ), dx , m_indices = self .m_indices , num_sms = 112
821
816
)
0 commit comments