 FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
 
 
-def fp8_per_tensor_gemm(quant_module, input, bias=None):
-    """GEMM function for fp8 per tensor quantization."""
+@torch.compile(dynamic=True)
+def _to_fp8(x, scale):
+    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
+
+
+@torch.compile(dynamic=True)
+def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
+    input_shape = input.shape
+    input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
+    weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
+    output = torch._scaled_mm(
+        input_fp8,
+        weight_fp8_t,
+        scale_a=scale_a,
+        scale_b=scale_b,
+        bias=bias,
+        out_dtype=input.dtype,
+        use_fast_accum=True,
+    )
+    return output.reshape(*input_shape[:-1], output.shape[-1])
 
-    @torch.compile(dynamic=True)
-    def _to_fp8(x, scale):
-        return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
-
-    @torch.compile(dynamic=True)
-    def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
-        input_shape = input.shape
-        input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
-        weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
-        output = torch._scaled_mm(
-            input_fp8,
-            weight_fp8_t,
-            scale_a=scale_a,
-            scale_b=scale_b,
-            bias=bias,
-            out_dtype=input.dtype,
-            use_fast_accum=True,
-        )
-        return output.reshape(*input_shape[:-1], output.shape[-1])
 
+def fp8_per_tensor_gemm(quant_module, input, bias=None):
+    """GEMM function for fp8 per tensor quantization."""
     cached_scale_a = (
         hasattr(quant_module, "_scale_a") and quant_module.input_quantizer.amax is not None
     )
 
     if not cached_scale_a:
-        input_amax = quant_module.input_quantizer.amax or reduce_amax(input)
+        input_amax = quant_module.input_quantizer.amax
+        if input_amax is None:
+            input_amax = reduce_amax(input)
         assert input_amax != 0
         quant_module._scale_a = (input_amax.float() / 448.0).to(device=input.device)
 
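A note on the "amax or reduce_amax(...)" fallback removed above: putting a tensor on the left of "or" invokes Tensor.__bool__, which raises for tensors with more than one element and treats a zero-valued scalar tensor as falsy, whereas the explicit "is None" check never evaluates tensor truthiness. A minimal sketch of the failure mode (illustrative only, not part of this diff):

import torch

amax = torch.tensor(0.5)
print(amax or "fallback")        # fine: a one-element tensor has a well-defined truth value

amax = torch.tensor([0.5, 1.0])  # e.g. a multi-element amax
try:
    _ = amax or "fallback"
except RuntimeError as err:
    print(err)                   # "Boolean value of Tensor with more than one element is ambiguous"
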
@@ -69,7 +72,9 @@ def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
     )
 
     if not cached_scale_b:
-        weight_amax = quant_module.weight_quantizer.amax or reduce_amax(quant_module.weight)
+        weight_amax = quant_module.weight_quantizer.amax
+        if weight_amax is None:
+            weight_amax = reduce_amax(quant_module.weight)
         assert weight_amax != 0
         quant_module._scale_b = (weight_amax.float() / 448.0).to(device=quant_module.weight.device)
 
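For reference, both scale computations above follow the same per-tensor recipe: scale = amax / 448, where 448 is torch.finfo(torch.float8_e4m3fn).max. A minimal round-trip sketch on CPU (illustrative only; the tensor here is made up):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max          # 448.0 for e4m3fn
x = torch.randn(16, 32) * 3.0                           # stand-in for activations or weights
scale = x.abs().amax().float() / FP8_MAX                # per-tensor scale, as in _scale_a / _scale_b

x_fp8 = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
x_deq = x_fp8.float() * scale                           # dequantize to inspect the rounding error
print((x - x_deq).abs().max())
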
@@ -146,9 +151,9 @@ def forward(
         ctx.save_for_backward(
             input_tensor if weight.requires_grad else None,
             weight if input_tensor.requires_grad else None,
-            torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None,
             getattr(quant_module.weight_quantizer, "_scale", None),
         )
+        ctx.compute_bias_grad = bias is not None and bias.requires_grad
         ctx.block_sizes = getattr(quant_module.weight_quantizer, "_block_sizes", None)
 
         ctx.allreduce_dgrad = allreduce_dgrad
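The hunk above stops smuggling the "compute bias grad" flag through save_for_backward as a dummy uint8 tensor: save_for_backward is meant for tensors that autograd must track, while plain Python flags can live directly on ctx. A stripped-down sketch of that pattern, using a hypothetical _LinearSketch function rather than the class in this file:

import torch

class _LinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, bias=None):
        # Tensors needed in backward go through save_for_backward ...
        ctx.save_for_backward(inp, weight)
        # ... while a plain bool flag is stored as an attribute on ctx.
        ctx.compute_bias_grad = bias is not None and bias.requires_grad
        out = inp @ weight.t()
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        grad_inp = grad_out @ weight
        grad_w = grad_out.t() @ inp
        grad_b = grad_out.sum(dim=0) if ctx.compute_bias_grad else None
        return grad_inp, grad_w, grad_b

x = torch.randn(6, 4, requires_grad=True)
w = torch.randn(5, 4, requires_grad=True)
b = torch.randn(5, requires_grad=True)
_LinearSketch.apply(x, w, b).sum().backward()
print(x.grad.shape, w.grad.shape, b.grad.shape)          # (6, 4) (5, 4) (5,)
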
@@ -166,7 +171,7 @@ def backward(ctx, grad_outputs):
         dequantize it to compute the input gradient. If the weight is not compressed, we will save
         the unquantized weight and use it directly to compute the input gradient.
         """
-        input_tensor, weight, compute_bias_grad, scale = ctx.saved_tensors
+        input_tensor, weight, scale = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         if weight is not None:
             if isinstance(weight, QTensorWrapper):
@@ -175,8 +180,10 @@ def backward(ctx, grad_outputs):
                 weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
             grad_input = grad_outputs @ weight
         if input_tensor is not None:
-            grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
-            if compute_bias_grad is not None:
+            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+                -1, input_tensor.shape[-1]
+            )
+            if ctx.compute_bias_grad:
                 # Sum all dimensions except the last one
                 grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
 
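The rewritten grad_weight above flattens all leading dimensions before the matmul, so batched (e.g. 3-D) activations collapse into a single (out_features, in_features) gradient instead of relying on a transpose of the last two axes. A quick shape and value check against einsum (illustrative only; the sizes are made up):

import torch

B, S, in_f, out_f = 2, 3, 4, 5
inp = torch.randn(B, S, in_f)
grad_out = torch.randn(B, S, out_f)

grad_w = grad_out.reshape(-1, grad_out.shape[-1]).T @ inp.reshape(-1, inp.shape[-1])
ref = torch.einsum("bso,bsi->oi", grad_out, inp)         # sum over batch and sequence
print(grad_w.shape, torch.allclose(grad_w, ref))         # torch.Size([5, 4]) True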