@@ -92,6 +92,8 @@ def fused_stack_quant(expert_weight_list, transpose=False):
         w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=True)
     elif transpose is True and hasattr(expert_weight_list[0], "fp8_weight_stacked"):
         w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=False)
+    elif transpose is False and hasattr(expert_weight_list[0], "fp8_weight_stacked_transpose"):
+        w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=True)
     else:
         w, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_weight_list, transpose=transpose)
     return w, scale
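The added branch handles expert weights that only carry the stacked transposed cache when the caller asks for the non-transposed result; it reads that cache through _get_fp8_weight_and_scale with transpose=True instead of falling through to a fresh fused_stack_transpose_quant. A paddle-free sketch of just the added guard, with SimpleNamespace objects as hypothetical stand-ins for expert weight parameters:

from types import SimpleNamespace

def hits_new_branch(expert0, transpose):
    # Mirrors only the condition added in this diff; the earlier branches of
    # fused_stack_quant still take precedence when their caches are present.
    return transpose is False and hasattr(expert0, "fp8_weight_stacked_transpose")

cached_transpose_only = SimpleNamespace(fp8_weight_stacked_transpose=object())
uncached = SimpleNamespace()

print(hits_new_branch(cached_transpose_only, transpose=False))  # True
print(hits_new_branch(uncached, transpose=False))               # False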
@@ -114,6 +116,8 @@ def weight_quant(weight, transpose=False):
     else:
         if hasattr(weight, "fp8_weight"):
             return weight.fp8_weight, weight.fp8_scale
+        elif hasattr(weight, "fp8_weight_transpose"):
+            return weight.fp8_weight_transpose.T.contiguous(), weight.fp8_scale_transpose.T.contiguous()
         else:
             return paddle.incubate.nn.functional.fp8_quant_blockwise(
                 weight,
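The new elif covers weights that were cached with quant_transpose=True and therefore only carry the transposed FP8 copy: transposing that copy (and its blockwise scale) back and making it contiguous recovers the plain layout this branch of weight_quant returns. A minimal sketch of the layout bookkeeping, with toy float tensors standing in for the real FP8 weight and scale caches:

import paddle

# Toy stand-in: a transpose-only cache holds the weight in transposed order.
w = paddle.randn([256, 512])
w_t_cached = w.T.contiguous()            # what fp8_weight_transpose would hold, shape [512, 256]
w_recovered = w_t_cached.T.contiguous()  # .T gives a strided view; .contiguous() rematerializes it

assert w_recovered.shape == w.shape
assert bool(paddle.allclose(w_recovered, w))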
@@ -596,23 +600,33 @@ def forward(self, x):
         return FP8LinearFunction.apply(x, self, keep_x=False)


-def cache_fp8_weight(weight, quant_transpose=True):
-    if hasattr(weight, "fp8_weight"):
+def cache_fp8_weight(weight, quant_transpose=None):
+    if hasattr(weight, "fp8_weight") or hasattr(weight, "fp8_weight_transpose"):
         return
-
-    if quant_transpose:
+    if quant_transpose is None:
         w_fp8, w_scale, w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             weight,
             output_scale_transpose=False,
             quant_method="128x128",
             input_transpose=True,
             return_transpose_only=False,
         )
+
         setattr(weight, "fp8_weight_transpose", w_t_fp8)
         setattr(weight, "fp8_scale_transpose", w_t_scale)
         setattr(weight, "fp8_weight", w_fp8)
         setattr(weight, "fp8_scale", w_scale)
-    else:
+    elif quant_transpose is True:
+        w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            weight,
+            output_scale_transpose=False,
+            quant_method="128x128",
+            input_transpose=True,
+            return_transpose_only=True,
+        )
+        setattr(weight, "fp8_weight_transpose", w_t_fp8)
+        setattr(weight, "fp8_scale_transpose", w_t_scale)
+    elif quant_transpose is False:
         w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             weight,
             output_scale_transpose=False,
@@ -622,6 +636,8 @@ def cache_fp8_weight(weight, quant_transpose=True):
         )
         setattr(weight, "fp8_weight", w_fp8)
         setattr(weight, "fp8_scale", w_scale)
+    else:
+        raise ValueError("quant_transpose must be either True, False or None.")


 class FP8KeepXLinear(paddle.nn.Layer):
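With the three-way dispatch in place, cache_fp8_weight now caches both layouts by default (quant_transpose=None), only the transposed pair for True, only the plain pair for False, and rejects anything else. A usage sketch; the parameter below is a toy stand-in and its shape and dtype are assumptions:

import paddle

w = paddle.create_parameter(shape=[1024, 1024], dtype="bfloat16")

cache_fp8_weight(w)                           # None: caches fp8_weight/fp8_scale and fp8_weight_transpose/fp8_scale_transpose
# cache_fp8_weight(w, quant_transpose=True)   # transposed pair only (return_transpose_only=True)
# cache_fp8_weight(w, quant_transpose=False)  # plain pair only
# cache_fp8_weight(w, quant_transpose=0)      # anything else raises ValueError

cache_fp8_weight(w, quant_transpose=True)     # no-op: either cached attribute short-circuits the call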
@@ -636,7 +652,7 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
         )
         set_parameter_color([self.weight], "attn_out_project")

-    def fp8_quant_weight(self, quant_transpose=True):
+    def fp8_quant_weight(self, quant_transpose=None):
         cache_fp8_weight(self.weight, quant_transpose=quant_transpose)

     def forward(self, x):
@@ -798,7 +814,7 @@ def __init__(
             is_bias=False,
         )

-    def fp8_quant_weight(self, quant_transpose=True):
+    def fp8_quant_weight(self, quant_transpose=None):
         cache_fp8_weight(self.w1, quant_transpose)
         cache_fp8_weight(self.w2, quant_transpose)

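Both layer methods keep forwarding quant_transpose unchanged, so the new None default reaches cache_fp8_weight from the layer API as well. A short usage sketch; FP8KeepXLinear is the class defined in this file, while the feature sizes are made-up values:

import paddle

layer = FP8KeepXLinear(1024, 4096)
layer.fp8_quant_weight()  # same as quant_transpose=None: both FP8 layouts get cached on layer.weight

assert hasattr(layer.weight, "fp8_weight")
assert hasattr(layer.weight, "fp8_weight_transpose")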
@@ -980,6 +996,10 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, tokens_per_expert, m_indi
         bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]])
         bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]])

+        if hasattr(expert_w2[0], "fp8_weight_stacked_transpose") and not hasattr(expert_w2[0], "fp8_weight_stacked"):
+            bw_w2_quant = bw_w2_quant.contiguous().transpose([0, 2, 1]).contiguous()
+            bw_w2_scale = bw_w2_scale.contiguous().transpose([0, 2, 1]).contiguous()
+
         # compute gemm
         if isinstance(unzipped_grad, tuple):
             (unzipped_grad_fp8, unzipped_grad_scale) = unzipped_grad
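When the experts carry only fp8_weight_stacked_transpose, the stacked quant and scale tensors reshaped above arrive in the transposed per-expert layout, and the added block permutes the last two axes back (the extra contiguous() calls rematerialize the strided result). A small sketch of that axis swap with toy float tensors; the expert count, sizes, and the assumption that the downstream grouped GEMM wants the un-transposed [E, rows, cols] layout are illustrative only:

import paddle

num_experts, rows, cols = 4, 256, 512
stacked_t = paddle.randn([num_experts, cols, rows])                   # layout of a transpose-only stacked cache
restored = stacked_t.contiguous().transpose([0, 2, 1]).contiguous()   # swap the per-expert axes back

assert restored.shape == [num_experts, rows, cols]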
@@ -1024,6 +1044,10 @@ def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, m_indices=None, d
         bw_w1_quant = bw_w1_quant.reshape([len(expert_w1), -1, bw_w1_quant.shape[-1]])
         bw_w1_scale = bw_w1_scale.reshape([len(expert_w1), -1, bw_w1_scale.shape[-1]])

+        if hasattr(expert_w1[0], "fp8_weight_stacked_transpose") and not hasattr(expert_w1[0], "fp8_weight_stacked"):
+            bw_w1_quant = bw_w1_quant.contiguous().transpose([0, 2, 1]).contiguous()
+            bw_w1_scale = bw_w1_scale.contiguous().transpose([0, 2, 1]).contiguous()
+
         # quant do1
         do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False