@@ -72,39 +72,33 @@ def fp8_linear(
 ) -> torch.Tensor:
     device = input.device
     origin_dtype = input.dtype
-    scale_a = 1.0
+    origin_shape = input.shape
+    input = input.reshape(-1, origin_shape[-1])
+
+    x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
+    fp8_max = 448.0
     # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
     # To avoid overflow and ensure numerical compatibility during FP8 computation,
     # we scale down the input by 2.0 in advance.
     # This scaling will be compensated later during the final result scaling.
     if DTYPE_FP8 == torch.float8_e4m3fnuz:
-        scale_a = 2.0
-        input = input / scale_a
+        fp8_max = fp8_max / 2.0
+    scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
+    scale_b = torch.ones((weight.shape[0], 1)).float().to(device=device)
+    input = input / scale_a
     input = input.to(DTYPE_FP8)
     weight = weight.to(DTYPE_FP8)
 
-    if len(input.shape) > 2:
-        origin_shape = input.shape
-        input = input.reshape(-1, origin_shape[-1])
-        result = torch._scaled_mm(
-            input,
-            weight.T,
-            scale_a=torch.tensor(scale_a).to(device=device),
-            scale_b=torch.tensor(1.0).to(device=device),
-            bias=bias,
-            out_dtype=origin_dtype,
-        )
-        new_shape = origin_shape[:-1] + result.shape[-1:]
-        result = result.reshape(new_shape)
-    else:
-        result = torch._scaled_mm(
-            input,
-            weight.T,
-            scale_a=torch.tensor(scale_a).to(device=device),
-            scale_b=torch.tensor(1.0).to(device=device),
-            bias=bias,
-            out_dtype=origin_dtype,
-        )
+    result = torch._scaled_mm(
+        input,
+        weight.T,
+        scale_a=scale_a,
+        scale_b=scale_b.T,
+        bias=bias,
+        out_dtype=origin_dtype,
+    )
+    new_shape = origin_shape[:-1] + result.shape[-1:]
+    result = result.reshape(new_shape)
     return result
 
 F.linear = fp8_linear
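
The diff replaces the single static FP8 scale with per-token dynamic scaling: the activation is always flattened to rows, each row gets its own quantization scale derived from its absolute maximum, and those scales are passed to torch._scaled_mm as row-wise dequantization factors. The sketch below reassembles that logic as a standalone function for reference; it is a minimal illustration, not part of the commit, and it assumes a PyTorch build whose private torch._scaled_mm accepts a row-wise scale_a and column-wise scale_b, an FP8-capable GPU, and inputs that already satisfy _scaled_mm's shape-alignment requirements. Names such as fp8_linear_sketch, x, w, and b are placeholders.

from typing import Optional

import torch


def fp8_linear_sketch(
    x: torch.Tensor,
    w: torch.Tensor,
    b: Optional[torch.Tensor] = None,
    dtype_fp8: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    device, out_dtype = x.device, x.dtype
    origin_shape = x.shape
    x = x.reshape(-1, origin_shape[-1])          # flatten leading dims into rows ("tokens")

    # Per-row absolute maximum drives a per-token quantization scale.
    x_max = x.abs().amax(dim=-1, keepdim=True)
    fp8_max = 448.0                              # max finite value of float8_e4m3fn
    if dtype_fp8 == torch.float8_e4m3fnuz:
        fp8_max /= 2.0                           # e4m3fnuz tops out at half of e4m3fn

    # clamp(min=1.0): only rows that would overflow the FP8 range are scaled down.
    scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device)
    scale_b = torch.ones((w.shape[0], 1), dtype=torch.float32, device=device)

    x_fp8 = (x / scale_a).to(dtype_fp8)          # quantize activations row by row
    w_fp8 = w.to(dtype_fp8)                      # weights keep a unit scale here

    # _scaled_mm multiplies the dequantization scales back into the product:
    # result ≈ (x_fp8 * scale_a) @ (w_fp8.T * scale_b.T) + b
    result = torch._scaled_mm(
        x_fp8,
        w_fp8.T,
        scale_a=scale_a,
        scale_b=scale_b.T,
        bias=b,
        out_dtype=out_dtype,
    )
    return result.reshape(origin_shape[:-1] + result.shape[-1:])

Because _scaled_mm multiplies scale_a back into the output, the per-row division before the FP8 cast and the dequantization scale cancel up to FP8 rounding error, while clamp(min=1.0) leaves rows that already fit the FP8 range untouched.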