
Commit ee03094

more cleanups; make amax float32 by default
Signed-off-by: realAsma <[email protected]>
1 parent c80416a commit ee03094

File tree

7 files changed: +67 −367 lines


modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 8 additions & 11 deletions
@@ -241,8 +241,8 @@ def amax(self, value):
 
         if not isinstance(value, torch.Tensor):
             value = torch.tensor(value)
-
-        if not hasattr(self, "_amax"):
+        value = value.to(torch.float32)
+        if not hasattr(self, "_amax") or self._amax.dtype != torch.float32:
             self.register_buffer("_amax", value.clone().detach())
         else:
             if self._amax.shape != value.shape:
@@ -265,7 +265,7 @@ def reset_bias(self):
     @property
     def step_size(self):
         """Return step size for integer quantization."""
-        if not hasattr(self, "_amax"):
+        if self.amax is None:
             warnings.warn("step_size is undefined under dynamic amax mode!")
             return None
         assert isinstance(self._num_bits, int), (
@@ -516,10 +516,7 @@ def load_calib_amax(self, *args, **kwargs):
                 err_msg
                 + " Passing 'strict=False' to `load_calib_amax()` will ignore the error."
             )
-        if not hasattr(self, "_amax"):
-            self.register_buffer("_amax", calib_amax.clone().detach())
-        else:
-            self._amax.data.copy_(calib_amax.clone().detach())
+        self.amax = calib_amax
 
     def load_calib_bias(self, *args, **kwargs):
         """Load affine bias for quantization."""
@@ -537,8 +534,10 @@ def load_calib_bias(self, *args, **kwargs):
 
     def _get_amax(self, inputs):
         """Get amax from buffer or compute it dynamically."""
-        if hasattr(self, "_amax"):
-            amax = self._amax
+        if self.amax is not None:
+            amax = self.amax
+            if amax.dtype != torch.float32:
+                self.amax = amax.to(torch.float32)
         else:
             reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(inputs, self._axis)
             amax = quant_utils.reduce_amax(inputs, axis=reduce_axis, keepdims=True).detach()
@@ -988,8 +987,6 @@ def _short_amax(self, fmt=".4f"):
             return "None"
         if not hasattr(self, "_amax"):
             return "dynamic"
-        if self._amax is None:
-            return "None"
         if self._amax.is_meta:
             return "meta"
         if self._amax.numel() == 1:
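
The net effect of these hunks is that amax is stored and served as a float32 buffer, even when calibration produced a half-precision tensor. A minimal standalone sketch of that setter contract (a toy class mirroring the diff above, not the actual TensorQuantizer):

import torch

class _AmaxHolder:
    """Toy stand-in that mirrors the amax setter semantics from the diff above."""

    def __init__(self):
        self._amax = None

    @property
    def amax(self):
        return self._amax

    @amax.setter
    def amax(self, value):
        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value)
        value = value.to(torch.float32)  # amax is float32 by default
        self._amax = value.clone().detach()

q = _AmaxHolder()
q.amax = torch.tensor([0.5], dtype=torch.float16)
assert q.amax.dtype == torch.float32  # stored as float32 regardless of the input dtype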

modelopt/torch/quantization/src/tensor_quant_fp8.cpp

Lines changed: 12 additions & 7 deletions
@@ -18,19 +18,24 @@
 #include <ATen/ATen.h>
 #include <cuda_fp8.h>
 #include <torch/extension.h>
+#include <optional>
 
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax);
+at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, std::optional<at::Tensor> amax);
 at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis);
 
-at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
-  TORCH_CHECK(amax.numel(), 1);
+at::Tensor fake_e4m3fy(at::Tensor inputs, std::optional<at::Tensor> amax) {
   inputs = inputs.contiguous();
-  auto amax_view = amax.view(-1).to(at::kFloat);
+  if (amax.has_value()) {
+    amax = amax.value().view(-1).to(at::kFloat);
+  }
   if (inputs.is_cuda()) {
-    return fake_e4m3fy_cuda(inputs, amax_view);
+    return fake_e4m3fy_cuda(inputs, amax);
   } else {
     TORCH_CHECK(inputs.dtype() == at::ScalarType::Float);
-    float scale = 448.f / amax_view[0].item<float>();
+    float scale = 1.f;
+    if (amax.has_value()) {
+      scale = 448.f / amax.value()[0].item<float>();
+    }
     float inv_scale = 1.f / scale;
     auto out = at::zeros_like(inputs);
     for (int i = 0; i < inputs.numel(); ++i) {
@@ -49,7 +54,7 @@ at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"),
-        py::arg("amax"));
+        py::arg("amax") = py::none());
   m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_with_axis, "Reduce precision to E4M3 (fused)",
         py::arg("inputs"), py::arg("amax"), py::arg("axis"));
 }
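
With py::arg("amax") = py::none(), the binding can now be called without an amax, and the kernel falls back to scale = 1, i.e. a plain E4M3 round-trip. For reference, a rough CPU-side emulation in plain PyTorch (an approximation built on torch.float8_e4m3fn, PyTorch 2.1+, not the compiled extension itself):

from typing import Optional

import torch

def fake_e4m3fy_cpu_sketch(inputs: torch.Tensor, amax: Optional[torch.Tensor] = None) -> torch.Tensor:
    # With amax: map [-amax, amax] onto the E4M3 range via scale = 448 / amax.
    # Without amax: scale = 1, i.e. just round-trip through float8_e4m3fn.
    scale = 448.0 / amax.view(-1)[0].item() if amax is not None else 1.0
    return (inputs * scale).to(torch.float8_e4m3fn).to(torch.float32) / scale

x = torch.randn(8)
y_scaled = fake_e4m3fy_cpu_sketch(x, x.abs().amax())
y_unscaled = fake_e4m3fy_cpu_sketch(x)  # amax omitted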

modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu

Lines changed: 8 additions & 2 deletions
@@ -19,6 +19,7 @@
 #include <c10/cuda/CUDAStream.h>
 #include <cuda_fp8.h>
 #include <torch/extension.h>
+#include <optional>
 
 #define BLOCK_SIZE 128
 
@@ -80,9 +81,14 @@ at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int ax
   return outputs;
 }
 
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax) {
+at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, std::optional<at::Tensor> amax_opt) {
   size_t numel = inputs.numel();
-  auto scale = 448.f / amax;
+  at::Tensor scale;
+  if (amax_opt.has_value()) {
+    scale = 448.f / amax_opt.value();
+  } else {
+    scale = at::ones({1}, inputs.options().dtype(at::kFloat));
+  }
   auto inv_scale = 1.f / scale;
   auto outputs = torch::empty_like(inputs);
   auto stream = c10::cuda::getCurrentCUDAStream();
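
The CUDA wrapper mirrors the same fallback: 448.0 is the largest finite E4M3 magnitude, so scale = 448 / amax maps [-amax, amax] onto the representable range, and a missing amax degenerates to scale = 1. A small sketch of just that scale selection (a hypothetical helper matching the branch above):

from typing import Optional

import torch

def e4m3_scale(amax: Optional[torch.Tensor] = None) -> torch.Tensor:
    # 448.0 is the max finite E4M3 value; no amax means no rescaling.
    if amax is not None:
        return 448.0 / amax
    return torch.ones(1, dtype=torch.float32)

print(e4m3_scale(torch.tensor([2.0])))  # tensor([224.])
print(e4m3_scale())                     # tensor([1.])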

modelopt/torch/quantization/tensor_quant.py

Lines changed: 4 additions & 201 deletions
@@ -60,10 +60,7 @@ def scaled_e4m3_impl(
     with torch.cuda.device(
         None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
     ):
-        if amax is None:
-            # This adds overhead; however this is not a common use case.
-            amax = torch.tensor(448.0, device=inputs.device, dtype=inputs.dtype)
-        if amax.numel() == 1:
+        if amax is None or amax.numel() == 1:
             outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
         else:
             if amax.squeeze().ndim > 1:
@@ -556,136 +553,6 @@ def backward(ctx, grad_outputs):
         return _fake_quant_backward_function(ctx, grad_outputs, num_args=9)
 
 
-class TensorQuantFunction(Function):
-    """A universal tensor quantization function.
-
-    Take an input tensor, output an quantized tensor. The granularity of scale can be interpreted from the
-    shape of amax.
-    output_dtype indicates whether the quantized value will be stored in integer or float. The reason we want to store
-    it in float is the pytorch function takes the quantized value may not accept integer input, e.g. Conv2D.
-
-    It uses 2^num_bits -1 values instead of 2^num_bits. e.g., for num_bits=8, it uses [-127, 127] instead of [-128, 127]
-    """
-
-    @staticmethod
-    @symbolic_helper.parse_args("v", "t", "t", "i", "b", "b", "s")
-    def symbolic(
-        g,
-        inputs,
-        amax,
-        bias=None,
-        num_bits=8,
-        unsigned=False,
-        narrow_range=True,
-        trt_high_precision_dtype=None,
-    ):
-        """ONNX symbolic function."""
-        from .export_onnx import export_int8
-
-        return export_int8(
-            g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype
-        )
-
-    @staticmethod
-    def forward(
-        ctx,
-        inputs,
-        amax,
-        bias=None,
-        num_bits=8,
-        unsigned=False,
-        narrow_range=True,
-        trt_high_precision_dtype=None,
-    ):
-        """Forward method.
-
-        Follow tensorflow convention, max value is passed in and used to decide scale, instead of inputting scale
-        directly. Though inputting scale directly may be more natural to use.
-
-        Args:
-            ctx: A Context object to store tensors for backward.
-            inputs: A Tensor of type float32.
-            amax: A Tensor of type float32. Inputs will be quantized within range [-amax, amax]
-                amax will be broadcasted to inputs tensor.
-            num_bits: A integer used to calculate scaling factor, scale = (2^(num_bits-1) - 1) / max
-                Effectively, it indicates how many integer bits is used to represent the value. Default 8.
-            output_dtype: A type of Tensor. torch.int32 or torch.float32.
-            unsigned: A boolean. Use unsigned integer range. E.g. [0, 255] for num_bits=8. Default False.
-            narrow_range: A boolean. Use symmetric integer range for signed quantization
-                E.g. [-127,127] instead of [-128,127] for num_bits=8. Default True.
-
-        Returns:
-            outputs: A Tensor of type output_dtype.
-            scale: A Tensor of type float32. outputs / scale will dequantize outputs tensor.
-
-        Raises:
-            ValueError:
-        """
-        if bias is not None:
-            inputs = inputs - bias
-
-        ctx.save_for_backward(inputs, amax)
-
-        outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
-        # Check if scale overflows FP16
-        if outputs.dtype == torch.half and scale.max() > 65504:
-            raise ValueError(f"scale is too large for FP16 with amax={amax}")
-
-        if bias is not None:
-            outputs = outputs + bias
-
-        return outputs, scale.to(inputs.dtype)
-
-    @staticmethod
-    def backward(ctx, grad_outputs, grad_scale):
-        """Implements straight through estimation with clipping.
-
-        For -amax <= input <= amax the gradient passes straight through, otherwise the gradient is zero.
-
-        Args:
-            ctx: A Context object with saved tensors from forward.
-            grad_outputs: A tensor of gradient of outputs.
-            grad_scale: A tensor of gradient of scale.
-
-        Returns:
-            grad_inputs: A tensor of gradient.
-        """
-        inputs, amax = ctx.saved_tensors
-        zero = grad_outputs.new_zeros(1)  # create a zero tensor with the same type and device
-        grad_inputs = torch.where(inputs.abs() <= amax, grad_outputs, zero)
-        return grad_inputs, None, None, None, None, None, None
-
-
-class LegacyFakeTensorQuantFunction(Function):
-    """Fake version of TensorQuantFunction.
-
-    See comments of TensorQuantFunction, arguments are the same.
-    """
-
-    @staticmethod
-    def forward(ctx, inputs, amax, bias, num_bits=8, unsigned=False, narrow_range=True):
-        """Forward method."""
-        if bias is not None:
-            inputs = inputs - bias
-
-        ctx.save_for_backward(inputs, amax)
-
-        outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
-
-        if bias is not None:
-            outputs = outputs + bias
-
-        return outputs / scale.to(inputs.dtype)
-
-    @staticmethod
-    def backward(ctx, grad_outputs):
-        """Implements straight through estimation."""
-        inputs, amax = ctx.saved_tensors
-        zero = grad_outputs.new_zeros(1)
-        grad_inputs = torch.where(inputs.abs() <= amax, grad_outputs, zero)
-        return grad_inputs, None, None, None, None, None
-
-
 def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
     """Shared function body between TensorQuantFunction and FakeTensorQuantFunction."""
     # Fine scale, per channel scale will be handled by broadcasting, which could be tricky. Pop a warning.
@@ -694,10 +561,8 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
 
     # Computation can be done in FP32 to prevent potential over flow.
     input_dtype = inputs.dtype
-    if inputs.dtype == torch.half:
-        inputs = inputs.float()
-    if amax.dtype == torch.half:
-        amax = amax.float()
+    inputs = inputs.float()
+    amax = amax.float()
 
     min_amax = amax.min()
     if min_amax < 0:
@@ -724,72 +589,10 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
             1.0  # Return 1 makes more sense for values quantized to 0 with amax=0
         )
 
-    if input_dtype == torch.half:
-        outputs = outputs.half()
-
+    outputs = outputs.to(input_dtype)
     return outputs, scale
 
 
-class FakeAffineTensorQuantFunction(Function):
-    """Fake version of affine quantization.
-
-    gemmlowp style scale+shift quantization. See more details in
-    https://github.com/google/gemmlowp/blob/master/doc/quantization.md.
-
-    We DO NOT recommend affine quantization on weights for performance reason. There might be value to affine quantize
-    activation as it can be cancelled by bias and comes with no performance penalty. This functionality is only added
-    for experimental purpose.
-    """
-
-    @staticmethod
-    def forward(ctx, inputs, min_range, max_range, num_bits=8):
-        """As it will be only applied on activation with per tensor granularity, broadcast is not needed.
-
-        Args:
-            ctx: Pytorch convention.
-            inputs: A Tensor of type float32.
-            min_range: A float.
-            max_range: A float.
-            num_bits: An integer
-
-        Returns:
-            outputs: A Tensor of type output_dtype
-        """
-        ctx.save_for_backward(inputs, min_range, max_range)
-
-        step_size = (max_range - min_range) / (2.0**num_bits - 1)
-
-        min_bound = -(2.0 ** (num_bits - 1))
-        max_bound = 2.0 ** (num_bits - 1) - 1
-
-        quant_zero = torch.round(min_range / step_size) - min_bound
-        quantized = torch.round(inputs / step_size) - quant_zero
-        quantized = torch.clamp(quantized, min_bound, max_bound)
-
-        outputs = (quantized + quant_zero) * step_size
-
-        return outputs
-
-    @staticmethod
-    def backward(ctx, grad_outputs):
-        """Implements straight through estimation with clipping.
-
-        Args:
-            ctx: Pytorch convention.
-            grad_output: A tensor of gradient of outputs.
-
-        Returns:
-            grad_inputs: A tensor of gradient
-        """
-        inputs, min_range, max_range = ctx.saved_tensors
-        zero = grad_outputs.new_zeros(1)
-        grad_inputs = torch.where((inputs <= max_range) * (inputs >= min_range), grad_outputs, zero)
-        return grad_inputs, None, None, None
-
-
-tensor_quant = TensorQuantFunction.apply
-legacy_fake_tensor_quant = LegacyFakeTensorQuantFunction.apply
 fake_tensor_quant = FakeTensorQuantFunction.apply
-fake_affine_tensor_quant = FakeAffineTensorQuantFunction.apply
 scaled_e4m3 = ScaledE4M3Function.apply
 dynamic_block_quant = DynamicBlockQuantizationFunction.apply
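
With TensorQuantFunction, LegacyFakeTensorQuantFunction, and FakeAffineTensorQuantFunction gone, fake_tensor_quant remains the module-level entry point, and _tensor_quant now always computes in float32 and casts the result back to the caller's dtype. A simplified fake-quant sketch of that dtype pattern (a hypothetical helper illustrating the idea, not the actual _tensor_quant, which also returns the scale):

import torch

def _fake_int8_sketch(inputs: torch.Tensor, amax: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    # Compute in float32 regardless of half/bfloat16 inputs, then cast back.
    input_dtype = inputs.dtype
    inputs = inputs.float()
    amax = amax.float()
    bound = 2.0 ** (num_bits - 1) - 1.0  # narrow symmetric range, e.g. [-127, 127]
    scale = bound / amax
    outputs = torch.clamp(torch.round(inputs * scale), -bound, bound) / scale
    return outputs.to(input_dtype)

x = torch.randn(8, dtype=torch.float16)
y = _fake_int8_sketch(x, x.abs().amax())
assert y.dtype == torch.float16  # result is cast back to the input dtype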
