 FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


-def fp8_per_tensor_gemm(quant_module, input, bias=None):
-    """GEMM function for fp8 per tensor quantization."""
+@torch.compile(dynamic=True)
+def _to_fp8(x, scale):
+    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
+
+
+@torch.compile(dynamic=True)
+def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
+    input_shape = input.shape
+    input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
+    weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
+    output = torch._scaled_mm(
+        input_fp8,
+        weight_fp8_t,
+        scale_a=scale_a,
+        scale_b=scale_b,
+        bias=bias,
+        out_dtype=input.dtype,
+        use_fast_accum=True,
+    )
+    return output.reshape(*input_shape[:-1], output.shape[-1])

-    @torch.compile(dynamic=True)
-    def _to_fp8(x, scale):
-        return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
-
-    @torch.compile(dynamic=True)
-    def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
-        input_shape = input.shape
-        input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
-        weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
-        output = torch._scaled_mm(
-            input_fp8,
-            weight_fp8_t,
-            scale_a=scale_a,
-            scale_b=scale_b,
-            bias=bias,
-            out_dtype=input.dtype,
-            use_fast_accum=True,
-        )
-        return output.reshape(*input_shape[:-1], output.shape[-1])

+def fp8_per_tensor_gemm(quant_module, input, bias=None):
+    """GEMM function for fp8 per tensor quantization."""
     cached_scale_a = (
         hasattr(quant_module, "_scale_a") and quant_module.input_quantizer.amax is not None
     )

     if not cached_scale_a:
-        input_amax = quant_module.input_quantizer.amax or reduce_amax(input)
+        input_amax = quant_module.input_quantizer.amax
+        if input_amax is None:
+            input_amax = reduce_amax(input)
         assert input_amax != 0
         quant_module._scale_a = (input_amax.float() / 448.0).to(device=input.device)
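(Aside, not part of the diff.) The `amax or reduce_amax(...)` pattern removed above is fragile because `or` calls `bool()` on the tensor: a recorded amax of zero is silently discarded, and a multi-element amax raises, which is presumably why the explicit None check was introduced. A minimal sketch of those semantics:

import torch

amax = torch.tensor(0.0)
print(amax or "fallback")        # prints "fallback": a zero amax is falsy and gets dropped

amax = torch.tensor([1.0, 2.0])
try:
    _ = amax or "fallback"
except RuntimeError as err:
    print(err)                   # "Boolean value of Tensor with more than one element is ambiguous"

# The explicit `if amax is None:` check in the new code only falls back when no amax was recorded.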
@@ -69,7 +72,9 @@ def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
     )

     if not cached_scale_b:
-        weight_amax = quant_module.weight_quantizer.amax or reduce_amax(quant_module.weight)
+        weight_amax = quant_module.weight_quantizer.amax
+        if weight_amax is None:
+            weight_amax = reduce_amax(quant_module.weight)
         assert weight_amax != 0
         quant_module._scale_b = (weight_amax.float() / 448.0).to(device=quant_module.weight.device)
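(Aside, not part of the diff.) The 448.0 in both scale computations is torch.finfo(torch.float8_e4m3fn).max, so each scale maps the observed amax onto the representable FP8 E4M3 range. A minimal round-trip sketch of that per-tensor scheme, with illustrative tensor names and run on CPU only (the real path hands the FP8 tensors to torch._scaled_mm on GPU):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max   # 448.0
x = torch.randn(128, 256)
amax = x.abs().amax()                            # stand-in for reduce_amax(x)
scale = amax.float() / FP8_MAX
x_fp8 = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)   # mirrors _to_fp8
x_back = x_fp8.to(torch.float32) * scale         # dequantize
print((x - x_back).abs().max())                  # worst-case per-tensor quantization error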
@@ -146,9 +151,9 @@ def forward(
         ctx.save_for_backward(
             input_tensor if weight.requires_grad else None,
             weight if input_tensor.requires_grad else None,
-            torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None,
             getattr(quant_module.weight_quantizer, "_scale", None),
         )
+        ctx.compute_bias_grad = bias is not None and bias.requires_grad
         ctx.block_sizes = getattr(quant_module.weight_quantizer, "_block_sizes", None)

         ctx.allreduce_dgrad = allreduce_dgrad
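(Aside, not part of the diff.) The change above stops smuggling a placeholder uint8 tensor through save_for_backward and keeps the bias-gradient flag as plain Python state on ctx; save_for_backward is only needed for tensors autograd must track. A toy Function sketching that pattern (BiasedDouble is hypothetical, not from this codebase):

import torch

class BiasedDouble(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, bias=None):
        # Plain attribute on ctx, analogous to ctx.compute_bias_grad above.
        ctx.compute_bias_grad = bias is not None and bias.requires_grad
        return 2 * x + (bias if bias is not None else 0)

    @staticmethod
    def backward(ctx, grad_out):
        grad_bias = None
        if ctx.compute_bias_grad:
            # Sum all dimensions except the last one, as in the real backward.
            grad_bias = grad_out.sum(dim=list(range(grad_out.dim() - 1)))
        return 2 * grad_out, grad_bias

x = torch.randn(4, 8, requires_grad=True)
bias = torch.zeros(8, requires_grad=True)
BiasedDouble.apply(x, bias).sum().backward()
print(x.grad.unique(), bias.grad)   # x.grad is all 2s; bias.grad is all 4s (one per row)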
@@ -166,7 +171,7 @@ def backward(ctx, grad_outputs):
         dequantize it to compute the input gradient. If the weight is not compressed, we will save
         the unquantized weight and use it directly to compute the input gradient.
         """
-        input_tensor, weight, compute_bias_grad, scale = ctx.saved_tensors
+        input_tensor, weight, scale = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         if weight is not None:
             if isinstance(weight, QTensorWrapper):
@@ -175,8 +180,10 @@ def backward(ctx, grad_outputs):
                 weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
             grad_input = grad_outputs @ weight
         if input_tensor is not None:
-            grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
-        if compute_bias_grad is not None:
+            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+                -1, input_tensor.shape[-1]
+            )
+        if ctx.compute_bias_grad:
             # Sum all dimensions except the last one
             grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
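(Aside, not part of the diff.) The grad_weight fix matters for inputs with more than two dimensions: transpose(-2, 1) transposes a 2-D tensor, but on a 3-D tensor dim -2 is dim 1, so it swaps that dim with itself and the following matmul no longer yields the weight gradient. Flattening the leading dimensions first gives the standard linear-layer gradient, checked below against autograd with illustrative shapes:

import torch

x = torch.randn(2, 5, 8, requires_grad=True)   # (batch, seq, in_features)
w = torch.randn(16, 8, requires_grad=True)     # (out_features, in_features)
y = x @ w.T                                    # (batch, seq, out_features)
grad_y = torch.randn_like(y)
y.backward(grad_y)

grad_w = grad_y.reshape(-1, grad_y.shape[-1]).T @ x.reshape(-1, x.shape[-1])
grad_b = grad_y.sum(dim=list(range(grad_y.dim() - 1)))   # same reduction as the bias gradient above
print(torch.allclose(w.grad, grad_w, atol=1e-5))          # True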