diff --git a/modelopt/torch/quantization/extensions.py b/modelopt/torch/quantization/extensions.py index f108e4192..003703567 100644 --- a/modelopt/torch/quantization/extensions.py +++ b/modelopt/torch/quantization/extensions.py @@ -41,7 +41,7 @@ def get_cuda_ext_fp8(raise_if_failed: bool = False): if not hasattr(get_cuda_ext_fp8, "extension"): get_cuda_ext_fp8.extension = load_cpp_extension( # type:ignore[attr-defined] name="modelopt_cuda_ext_fp8", - sources=[path / "src/tensor_quant_fp8.cpp", path / "src/tensor_quant_gpu_fp8.cu"], + sources=[path / "src/tensor_quant_gpu_fp8.cu"], cuda_version_specifiers=">=11.8", fail_msg=( "CUDA extension for FP8 quantization could not be built and loaded, FP8 simulated" diff --git a/modelopt/torch/quantization/src/tensor_quant.h b/modelopt/torch/quantization/src/tensor_quant.h index de0c9a9cd..625566032 100644 --- a/modelopt/torch/quantization/src/tensor_quant.h +++ b/modelopt/torch/quantization/src/tensor_quant.h @@ -22,7 +22,6 @@ void fake_tensor_quant_cuda_inplace(at::Tensor, at::Tensor, int, bool, bool); at::Tensor fake_tensor_quant_cuda(at::Tensor, at::Tensor, int, bool, bool); at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor, at::Tensor, int, int, bool, bool); float bits_to_bound(int, int); -at::Tensor fake_e4m3fy_cuda(at::Tensor inputs); // Dequantizes data using NF4 quantization scheme and per-block scaling factors. // diff --git a/modelopt/torch/quantization/src/tensor_quant_fp8.cpp b/modelopt/torch/quantization/src/tensor_quant_fp8.cpp deleted file mode 100644 index 2bf9669ac..000000000 --- a/modelopt/torch/quantization/src/tensor_quant_fp8.cpp +++ /dev/null @@ -1,48 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#include
-#include
-#include
-
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs);
-at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold);
-
-at::Tensor fake_e4m3fy(at::Tensor inputs) {
-  if (inputs.is_cuda()) {
-    return fake_e4m3fy_cuda(inputs.contiguous());
-  } else {
-    TORCH_CHECK(inputs.dtype() == at::ScalarType::Float);
-    TORCH_CHECK(inputs.is_contiguous());
-    auto out = at::zeros_like(inputs);
-    for (int i = 0; i < inputs.numel(); ++i) {
-      out.data_ptr<float>()[i] =
-          static_cast<float>(static_cast<__nv_fp8_e4m3>(inputs.data_ptr<float>()[i]));
-    }
-    return out;
-  }
-}
-
-at::Tensor fused_fake_e4m3fy(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
-  return fused_fake_e4m3fy_cuda(inputs.contiguous(), amax, zero_threshold);
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"));
-  m.def("fused_fake_e4m3fy", &fused_fake_e4m3fy, "Reduce precision to E4M3 (fused)",
-        py::arg("inputs"), py::arg("amax"), py::arg("zero_threshold"));
-}
diff --git a/modelopt/torch/quantization/src/tensor_quant_gpu.cu b/modelopt/torch/quantization/src/tensor_quant_gpu.cu
index 0d646e081..f26931c00 100644
--- a/modelopt/torch/quantization/src/tensor_quant_gpu.cu
+++ b/modelopt/torch/quantization/src/tensor_quant_gpu.cu
@@ -16,6 +16,7 @@
  */
 
 #include
+#include
 #include
 #include
 #include
@@ -74,8 +75,9 @@ __global__ void fake_tensor_quant_kernel(const T *inputs, size_t n, T *outputs,
 void fake_tensor_quant_cuda_inplace(at::Tensor inputs, at::Tensor amax, int num_bits = 8,
                                     bool is_unsigned = false, bool narrow_range = true) {
   size_t numel = inputs.numel();
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda_inplace", [&] {
-    fake_tensor_quant_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, inputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), num_bits, is_unsigned, narrow_range);
   });
@@ -85,8 +87,9 @@ at::Tensor fake_tensor_quant_cuda(at::Tensor inputs, at::Tensor amax, int num_bi
                                   bool is_unsigned = false, bool narrow_range = true) {
   size_t numel = inputs.numel();
   auto outputs = torch::empty_like(inputs);
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda", [&] {
-    fake_tensor_quant_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), num_bits, is_unsigned, narrow_range);
   });
@@ -125,8 +128,10 @@ at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor inputs, at::Tensor amax,
 
   int outer_size = inputs.stride(axis);
 
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda_with_axis", [&] {
-    fake_tensor_quant_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0,
+                                              stream>>>(
         inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), axis_size, outer_size, num_bits,
         is_unsigned, narrow_range);
diff --git a/modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu b/modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
index 9081d0e14..38a66562f 100644
--- a/modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
+++ b/modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #define BLOCK_SIZE 128
@@ -31,92 +32,77 @@
 #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                                               \
   AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
-template <typename T> __global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, T *outputs) {
+template <typename T>
+__global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, const float *scale,
+                                   const float *inv_scale, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     outputs[idx] = static_cast<T>(
-        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]))));
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]) * scale[0])) *
+        inv_scale[0]);
   }
 }
 
 template <typename T>
-__global__ void fused_fake_e4m3fy_kernel(const T *inputs, size_t n, float *amax,
-                                         bool per_block_scaling_factor, size_t blocksize,
-                                         float zero_threshold, T *outputs) {
+__global__ void fake_e4m3fy_with_axis_cuda_kernel(const T *inputs, size_t n, const float *scale,
+                                                  const float *inv_scale, int axis_size,
+                                                  int outer_size, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     float x = static_cast<float>(inputs[idx]);
-    // generate mask for zeroing tiny values
-    float x_abs = fabsf(x);
-    bool zero_mask = x_abs < zero_threshold;
-
-    // grab the global scaling factor
-    size_t amax_idx = (per_block_scaling_factor) ? (idx / blocksize) : 0;
-
-    // compute scale and inverse-scales
-    float scale = 448.f / (amax[amax_idx]);
-    float inv_scale = 1.f / scale;
+    int axis_id = (idx / outer_size) % axis_size;
 
     // compute the output
-    float output = static_cast<float>(static_cast<__nv_fp8_e4m3>(scale * x)) * inv_scale;
-
-    // zero out small values
-    if (zero_mask) {
-      output = 0.f;
-    }
+    float output =
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(scale[axis_id] * x)) * inv_scale[axis_id];
 
     outputs[idx] = output;
   }
 }
 
-at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
-  size_t numel = inputs.numel();
+at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+  inputs = inputs.contiguous();
+  amax = amax.contiguous().to(at::kFloat);
   auto outputs = torch::empty_like(inputs);
+  size_t numel = inputs.numel();
+  int axis_size = inputs.size(axis);
+  int outer_size = inputs.stride(axis);
 
-  bool per_block_scaling_factor = false;
-  size_t blocksize = numel;
-
-  int amax_ndim = amax.dim();
-  int input_ndim = inputs.dim();
-
-  // 3 options:
-  // 1.
-  //    inputs[numel], amax[1] -> per-tensor scaling
-  // 2.
-  //    inputs[numel], amax[numel/num_cols] -> per-row / per-channel scaling
-  // 3.
-  //    inputs[numel/bs, bs], amax[numel/bs, 1] -> blockwise scaling
-  if (amax.numel() == 1) {
-    // case 1.
-    per_block_scaling_factor = false;
-  } else if (amax.numel() > 1 && (amax_ndim > 1 && (amax.size(-1) == amax.numel()))) {
-    // case 2.
-    per_block_scaling_factor = true;
-    blocksize = numel / amax.numel();
-  } else {
-    throw std::runtime_error("invalid combination of inputs and amax shapes/sizes");
-  }
+  auto scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
 
   auto stream = c10::cuda::getCurrentCUDAStream();
-
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fused_fake_e4m3fy_cuda", [&] {
-    fused_fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, amax.data_ptr<float>(), per_block_scaling_factor,
-        blocksize, zero_threshold, outputs.data_ptr<scalar_t>());
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] {
+    fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        axis_size, outer_size, outputs.data_ptr<scalar_t>());
   });
+
   return outputs;
 }
 
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs) {
+at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
+  inputs = inputs.contiguous();
+  amax = amax.view(-1).to(at::kFloat);
   size_t numel = inputs.numel();
+  at::Tensor scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
   auto outputs = torch::empty_like(inputs);
   auto stream = c10::cuda::getCurrentCUDAStream();
 
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy", [&] {
     fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>());
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        outputs.data_ptr<scalar_t>());
   });
 
   return outputs;
 }
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"),
+        py::arg("amax"));
+  m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_with_axis, "Reduce precision to E4M3 (fused)",
+        py::arg("inputs"), py::arg("amax"), py::arg("axis"));
+}
diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py
index 9bd73d909..5f69e3999 100644
--- a/modelopt/torch/quantization/tensor_quant.py
+++ b/modelopt/torch/quantization/tensor_quant.py
@@ -42,10 +42,26 @@
 DISABLE_TRITON_KERNEL = False
 
 
+def _fp8_eager(x, amax=None):
+    dtype = x.dtype
+    if amax is not None:
+        scale = 448.0 / (amax.to(torch.float32))
+        scale_inv = 1 / scale
+        x = x.to(torch.float32) * scale
+    x = x.to(torch.float8_e4m3fn)
+    if amax is not None:
+        x = x.to(torch.float32) * scale_inv
+    return x.to(dtype)
+
+
+def fp8_eager(x, amax):
+    """Eager mode implementation of FP8 quantization."""
+    return _fp8_eager(x, amax)
+
+
 def scaled_e4m3_impl(
-    inputs: torch.Tensor,  # TODO: check support for multiple inputs
-    amax: torch.Tensor,
-    disable_fused_kernel=True,
+    inputs: torch.Tensor,
+    amax: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """Implementation of fake quantizing input to FP8.
 
@@ -56,44 +72,22 @@ def scaled_e4m3_impl(
     Returns:
         Input tensors faked quantized to FP8.
     """
-    cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=True)
-
-    def is_fusable():
-        # ignore no scaling and shape([]) cases
-        if amax is None or len(amax.shape) == 0:
-            return False
-        else:
-            # can't have amax.shape = [1, 1, 4, 1] and the like
-            amax_last_dim_only = amax.numel() == amax.shape[-1]
-            # must be cuda
-            all_cuda = inputs.is_cuda and amax.is_cuda
+    if (not inputs.is_cuda) or amax is None or amax.squeeze().ndim > 1:
+        return fp8_eager(inputs, amax)
 
-            # also check explicit disable.
- return amax_last_dim_only and all_cuda and (not disable_fused_kernel) + cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=False) + if cuda_ext_fp8 is None: + return fp8_eager(inputs, amax) with torch.cuda.device( None if inputs.device.index == torch.cuda.current_device() else inputs.device.index ): - # differentiate between fused & unfused cases - if is_fusable(): - zero_threshold = 1.0 / (1 << 24) - outputs = cuda_ext_fp8.fused_fake_e4m3fy(inputs, amax.float(), zero_threshold) - else: - zero_mask = inputs.abs() < 1.0 / (1 << 24) - - if amax is None: - outputs = cuda_ext_fp8.fake_e4m3fy(inputs) - else: - scale = 448.0 / amax - outputs = cuda_ext_fp8.fake_e4m3fy(inputs * scale) / scale - - # Zero out values that are tiny. - # Tiny values could lead to tiny amax and then large scale which cause overflow/saturation - # and won't go back to normal value after dividing by scale. The right behavior is to mark them - # as zero which also get rid of inf/nan - outputs[zero_mask] = 0.0 - - return outputs + if amax.numel() == 1: + outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax) + elif amax.squeeze().ndim == 1: + axis = amax.shape.index(amax.numel()) + outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis) + return outputs def fake_quant_impl( @@ -377,8 +371,8 @@ def forward( def legacy_quant_func(): # The LegacyFakeTensorQuantFunction support cpu and amax with any shape that can be broadcasted to inputs. - outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range) - return outputs / scale.to(inputs.dtype) + outputs = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range) + return outputs if not inputs.is_cuda: outputs = legacy_quant_func() @@ -577,136 +571,6 @@ def backward(ctx, grad_outputs): return _fake_quant_backward_function(ctx, grad_outputs, num_args=9) -class TensorQuantFunction(Function): - """A universal tensor quantization function. - - Take an input tensor, output an quantized tensor. The granularity of scale can be interpreted from the - shape of amax. - output_dtype indicates whether the quantized value will be stored in integer or float. The reason we want to store - it in float is the pytorch function takes the quantized value may not accept integer input, e.g. Conv2D. - - It uses 2^num_bits -1 values instead of 2^num_bits. e.g., for num_bits=8, it uses [-127, 127] instead of [-128, 127] - """ - - @staticmethod - @symbolic_helper.parse_args("v", "t", "t", "i", "b", "b", "s") - def symbolic( - g, - inputs, - amax, - bias=None, - num_bits=8, - unsigned=False, - narrow_range=True, - trt_high_precision_dtype=None, - ): - """ONNX symbolic function.""" - from .export_onnx import export_int8 - - return export_int8( - g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype - ) - - @staticmethod - def forward( - ctx, - inputs, - amax, - bias=None, - num_bits=8, - unsigned=False, - narrow_range=True, - trt_high_precision_dtype=None, - ): - """Forward method. - - Follow tensorflow convention, max value is passed in and used to decide scale, instead of inputting scale - directly. Though inputting scale directly may be more natural to use. - - Args: - ctx: A Context object to store tensors for backward. - inputs: A Tensor of type float32. - amax: A Tensor of type float32. Inputs will be quantized within range [-amax, amax] - amax will be broadcasted to inputs tensor. 
- num_bits: A integer used to calculate scaling factor, scale = (2^(num_bits-1) - 1) / max - Effectively, it indicates how many integer bits is used to represent the value. Default 8. - output_dtype: A type of Tensor. torch.int32 or torch.float32. - unsigned: A boolean. Use unsigned integer range. E.g. [0, 255] for num_bits=8. Default False. - narrow_range: A boolean. Use symmetric integer range for signed quantization - E.g. [-127,127] instead of [-128,127] for num_bits=8. Default True. - - Returns: - outputs: A Tensor of type output_dtype. - scale: A Tensor of type float32. outputs / scale will dequantize outputs tensor. - - Raises: - ValueError: - """ - if bias is not None: - inputs = inputs - bias - - ctx.save_for_backward(inputs, amax) - - outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range) - # Check if scale overflows FP16 - if outputs.dtype == torch.half and scale.max() > 65504: - raise ValueError(f"scale is too large for FP16 with amax={amax}") - - if bias is not None: - outputs = outputs + bias - - return outputs, scale.to(inputs.dtype) - - @staticmethod - def backward(ctx, grad_outputs, grad_scale): - """Implements straight through estimation with clipping. - - For -amax <= input <= amax the gradient passes straight through, otherwise the gradient is zero. - - Args: - ctx: A Context object with saved tensors from forward. - grad_outputs: A tensor of gradient of outputs. - grad_scale: A tensor of gradient of scale. - - Returns: - grad_inputs: A tensor of gradient. - """ - inputs, amax = ctx.saved_tensors - zero = grad_outputs.new_zeros(1) # create a zero tensor with the same type and device - grad_inputs = torch.where(inputs.abs() <= amax, grad_outputs, zero) - return grad_inputs, None, None, None, None, None, None - - -class LegacyFakeTensorQuantFunction(Function): - """Fake version of TensorQuantFunction. - - See comments of TensorQuantFunction, arguments are the same. - """ - - @staticmethod - def forward(ctx, inputs, amax, bias, num_bits=8, unsigned=False, narrow_range=True): - """Forward method.""" - if bias is not None: - inputs = inputs - bias - - ctx.save_for_backward(inputs, amax) - - outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range) - - if bias is not None: - outputs = outputs + bias - - return outputs / scale.to(inputs.dtype) - - @staticmethod - def backward(ctx, grad_outputs): - """Implements straight through estimation.""" - inputs, amax = ctx.saved_tensors - zero = grad_outputs.new_zeros(1) - grad_inputs = torch.where(inputs.abs() <= amax, grad_outputs, zero) - return grad_inputs, None, None, None, None, None - - def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): """Shared function body between TensorQuantFunction and FakeTensorQuantFunction.""" # Fine scale, per channel scale will be handled by broadcasting, which could be tricky. Pop a warning. @@ -715,10 +579,8 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): # Computation can be done in FP32 to prevent potential over flow. 
input_dtype = inputs.dtype - if inputs.dtype == torch.half: - inputs = inputs.float() - if amax.dtype == torch.half: - amax = amax.float() + inputs = inputs.float() + amax = amax.float() min_amax = amax.min() if min_amax < 0: @@ -744,73 +606,12 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): scale[zero_amax_mask] = ( 1.0 # Return 1 makes more sense for values quantized to 0 with amax=0 ) + outputs = outputs / scale - if input_dtype == torch.half: - outputs = outputs.half() - - return outputs, scale - - -class FakeAffineTensorQuantFunction(Function): - """Fake version of affine quantization. - - gemmlowp style scale+shift quantization. See more details in - https://github.com/google/gemmlowp/blob/master/doc/quantization.md. - - We DO NOT recommend affine quantization on weights for performance reason. There might be value to affine quantize - activation as it can be cancelled by bias and comes with no performance penalty. This functionality is only added - for experimental purpose. - """ - - @staticmethod - def forward(ctx, inputs, min_range, max_range, num_bits=8): - """As it will be only applied on activation with per tensor granularity, broadcast is not needed. - - Args: - ctx: Pytorch convention. - inputs: A Tensor of type float32. - min_range: A float. - max_range: A float. - num_bits: An integer - - Returns: - outputs: A Tensor of type output_dtype - """ - ctx.save_for_backward(inputs, min_range, max_range) - - step_size = (max_range - min_range) / (2.0**num_bits - 1) - - min_bound = -(2.0 ** (num_bits - 1)) - max_bound = 2.0 ** (num_bits - 1) - 1 - - quant_zero = torch.round(min_range / step_size) - min_bound - quantized = torch.round(inputs / step_size) - quant_zero - quantized = torch.clamp(quantized, min_bound, max_bound) - - outputs = (quantized + quant_zero) * step_size - - return outputs - - @staticmethod - def backward(ctx, grad_outputs): - """Implements straight through estimation with clipping. - - Args: - ctx: Pytorch convention. - grad_output: A tensor of gradient of outputs. 
- - Returns: - grad_inputs: A tensor of gradient - """ - inputs, min_range, max_range = ctx.saved_tensors - zero = grad_outputs.new_zeros(1) - grad_inputs = torch.where((inputs <= max_range) * (inputs >= min_range), grad_outputs, zero) - return grad_inputs, None, None, None + outputs = outputs.to(input_dtype) + return outputs -tensor_quant = TensorQuantFunction.apply -legacy_fake_tensor_quant = LegacyFakeTensorQuantFunction.apply fake_tensor_quant = FakeTensorQuantFunction.apply -fake_affine_tensor_quant = FakeAffineTensorQuantFunction.apply scaled_e4m3 = ScaledE4M3Function.apply dynamic_block_quant = DynamicBlockQuantizationFunction.apply diff --git a/tests/_test_utils/torch_quantization/tensor_quant_common.py b/tests/_test_utils/torch_quantization/tensor_quant_common.py index 413cba284..266f382d9 100644 --- a/tests/_test_utils/torch_quantization/tensor_quant_common.py +++ b/tests/_test_utils/torch_quantization/tensor_quant_common.py @@ -123,17 +123,6 @@ def test_full_range(self): assert torch.allclose(quant_x_test, quant_x_ref) -class TensorQuantTester(TensorQuantCommon): - func = tensor_quant.tensor_quant - is_fake = False - return_tuple = True - - def test_overflow_fp16(self): - x = torch.randn(31).to(self.device).half() - with pytest.raises(ValueError, match="scale is too large for FP16"): - _ = self.func(x, torch.tensor(1e-4).to(self.device).half(), None, 8, False) - - class FakeTensorQuantTester(TensorQuantCommon): func = tensor_quant.fake_tensor_quant is_fake = True @@ -145,72 +134,3 @@ def test_overflow_fp16(self): x, torch.tensor(1e-4).to(self.device).half(), 8, False ) assert not (torch.isinf(quant_x_test).any() or torch.isnan(quant_x_test).any()) - - @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) - @pytest.mark.parametrize("num_bits", [3, 4, 5, 7, 8, 11]) - @pytest.mark.parametrize("unsigned", [True, False]) - def test_against_legacy(self, dtype, num_bits, unsigned): - torch.manual_seed(123456) - x = torch.randn(3, 4, 5, 6).to(dtype).to(self.device) - - amax_torch = torch.tensor(0.7).to(self.device) - - if unsigned: - x = x.abs() - legacy_out = tensor_quant.legacy_fake_tensor_quant(x, amax_torch, None, num_bits, unsigned) - test_out = tensor_quant.fake_tensor_quant(x, amax_torch, None, num_bits, unsigned) - if dtype == torch.float16: - assert torch.allclose(legacy_out, test_out, rtol=1e-3, atol=1e-4) - else: - assert torch.allclose(legacy_out, test_out) - - def test_against_legacy_noncontiguous(self): - x = torch.randn(3, 4, 5, 6).to(self.device) - - amax_torch = torch.tensor(0.7).to(self.device) - - x_torch_noncontiguous = x[:, 2, :, 3] - assert not x_torch_noncontiguous.is_contiguous() - - legacy_out = tensor_quant.legacy_fake_tensor_quant(x_torch_noncontiguous, amax_torch, None) - test_out = tensor_quant.fake_tensor_quant(x_torch_noncontiguous, amax_torch, None) - assert torch.allclose(legacy_out, test_out, rtol=0, atol=0) - - @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) - @pytest.mark.parametrize("num_bits", [3, 4, 5, 7, 8, 11]) - @pytest.mark.parametrize("unsigned", [True, False]) - def test_against_legacy_with_axis(self, dtype, num_bits, unsigned): - x = torch.randn(3, 4, 5, 6).to(self.device).to(dtype) - - # amax along axis 1 - amax_torch = torch.tensor([0.8, 0.9, 0.7, 0.6]).to(self.device).view(1, -1, 1, 1) - - if unsigned: - x = x.abs() - legacy_out = tensor_quant.legacy_fake_tensor_quant(x, amax_torch, None, num_bits, unsigned) - test_out = tensor_quant.fake_tensor_quant(x, amax_torch, None, num_bits, unsigned) - assert 
torch.allclose( - legacy_out, test_out, atol=1e-3 if dtype == torch.float16 else 0, rtol=0 - ) - - -class FakeAffineTensorQuantTester: - device = None - - def test_simple_run(self): - x = torch.tensor([-1.0, -13.0, -101.0, -128.0, 0.0, 2.0, 5.0, 13.0, 93.0, 111.0, 127.0]).to( - self.device - ) - quant_x = tensor_quant.fake_affine_tensor_quant(x, torch.min(x), torch.max(x)) - assert torch.allclose(quant_x, x) - - def test_clip_gradient(self): - x = torch.randn(3, 7, requires_grad=True).to(self.device) - x.retain_grad() - xmin = x.min() / 2 - xmax = x.max() / 2 - x_in_range = (xmin <= x) * (x <= xmax) - quant_x = tensor_quant.fake_affine_tensor_quant(x, xmin, xmax, 8) - loss = torch.sum((quant_x - 0.5) ** 2) - loss.backward() - assert torch.equal(x.grad != 0, x_in_range) diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py index 430b96783..9519824bb 100644 --- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py +++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py @@ -18,23 +18,15 @@ import pytest import torch from _test_utils.torch_quantization.quant_utils import quant -from _test_utils.torch_quantization.tensor_quant_common import ( - FakeAffineTensorQuantTester, - FakeTensorQuantTester, - TensorQuantTester, -) +from _test_utils.torch_quantization.tensor_quant_common import FakeTensorQuantTester import modelopt.torch.quantization.triton as triton_kernel import modelopt.torch.quantization.utils as quant_utils from modelopt.torch.quantization import tensor_quant -from modelopt.torch.quantization.extensions import get_cuda_ext, get_cuda_ext_fp8, get_cuda_ext_mx +from modelopt.torch.quantization.extensions import get_cuda_ext, get_cuda_ext_mx from modelopt.torch.quantization.tensor_quant import mx_format_map -class TestTensorQuantCuda(TensorQuantTester): - device = "cuda" - - class TestFakeTensorQuantCuda(FakeTensorQuantTester): device = "cuda" @@ -47,10 +39,6 @@ def test_non_current_gpu(self, need_2_gpus): assert torch.allclose(quant_x, quant_x_ref) -class TestFakeAffineTensorQuantCuda(FakeAffineTensorQuantTester): - device = "cuda" - - class TestCudaExt: @pytest.mark.parametrize("num_bits", [3, 4, 5, 7, 8, 11]) @pytest.mark.parametrize("unsigned", [True, False]) @@ -120,48 +108,27 @@ def test_cuda_ext_tiny_amax(self): class TestScaledE4M3: - x = [ - [-2.0000, -1.8000, -1.6000, -1.4000, -1.2000], - [-1.0000, -0.8000, -0.6000, -0.4000, -0.2000], - [-0.0000, 0.2000, 0.4000, 0.6000, 0.8000], - [1.0000, 1.2000, 1.4000, 1.6000, 1.8000], - ] - - xq_unscaled = [ - [-2.0000, -1.7500, -1.6250, -1.3750, -1.2500], - [-1.0000, -0.8125, -0.6250, -0.4062, -0.2031], - [0.0000, 0.2031, 0.4062, 0.6250, 0.8125], - [1.0000, 1.2500, 1.3750, 1.6250, 1.7500], - ] - - xq_scaled = [ - [-2.0000, -1.8571, -1.5714, -1.4286, -1.1429], - [-1.0000, -0.7857, -0.5714, -0.3929, -0.1964], - [0.0000, 0.1964, 0.3929, 0.5714, 0.7857], - [1.0000, 1.1429, 1.4286, 1.5714, 1.8571], - ] - @pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_e4m3_no_scale(self, device): - x = torch.tensor(TestScaledE4M3.x).to(device) - xq_ref = torch.tensor(TestScaledE4M3.xq_unscaled).to(device) + x = torch.randn(4, 4, device=device, dtype=torch.float32) + xq_ref = tensor_quant.fp8_eager(x, torch.tensor(448.0, device=x.device)) e4m3_x = tensor_quant.scaled_e4m3(x, None, None, 4, 3) assert torch.allclose(e4m3_x, xq_ref, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize("device", ["cuda", "cpu"]) - def test_with_amax(self, device): - x = 
torch.tensor(TestScaledE4M3.x).to(device).unsqueeze(-1) - xq_ref = torch.tensor(TestScaledE4M3.xq_scaled).to(device).unsqueeze(-1) - + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_with_amax(self, device, dtype): + if device == "cpu" and dtype != torch.float32: + pytest.skip("CPU does not support non-float32 dtype") + x = torch.randn(4, 4, device=device, dtype=dtype) amax = quant_utils.reduce_amax(x, axis=None, keepdims=True) - - e4m3_x = tensor_quant.scaled_e4m3(x, amax, None, 4, 3) - - assert torch.allclose(e4m3_x, xq_ref, atol=1e-4, rtol=1e-4) + xq_ref = tensor_quant.fp8_eager(x, amax) + xq_test = tensor_quant.scaled_e4m3(x, amax, None, 4, 3) + assert torch.allclose(xq_test, xq_ref) def test_e4m3_incontiguous(self): - x = torch.tensor(TestScaledE4M3.x).cuda().transpose(1, 0) - xq_ref = torch.tensor(TestScaledE4M3.xq_unscaled).cuda().transpose(1, 0) + x = torch.randn(4, 4).cuda().transpose(1, 0) + xq_ref = tensor_quant.fp8_eager(x, torch.tensor(448.0, device=x.device)) assert not x.is_contiguous() e4m3_x = tensor_quant.scaled_e4m3(x, None, None, 4, 3) assert torch.allclose(e4m3_x, xq_ref, atol=1e-4, rtol=1e-4) @@ -179,28 +146,21 @@ def test_backward(self, device): assert torch.allclose(quant_x.grad, x.grad) def test_non_current_gpu(self, need_2_gpus): + torch.cuda.set_device(0) device = torch.cuda.device_count() - 1 - assert torch.cuda.current_device() != device x = torch.randn(3, 4).cuda() - quant_x_ref = tensor_quant.scaled_e4m3(x, x.amax(), None, 4, 3) + quant_x_ref = tensor_quant.fp8_eager(x, torch.tensor(448.0, device=x.device)) x = x.cuda(device) - quant_x = tensor_quant.scaled_e4m3(x, x.amax(), None, 4, 3) - assert torch.allclose(quant_x, quant_x_ref.cuda(device)) - - def test_fused_e4m3_kernel(self): - cuda_ext_fp8 = get_cuda_ext_fp8() - x = torch.tensor(TestScaledE4M3.x).cuda() - xq_ref = torch.tensor(TestScaledE4M3.xq_scaled).cuda() - amax = torch.ones(1, x.shape[-1]).cuda() * x.abs().amax() - e4m3_x = cuda_ext_fp8.fused_fake_e4m3fy(x, amax.float(), 1.0 / (1 << 24)) - assert torch.allclose(e4m3_x, xq_ref, atol=1e-4, rtol=1e-4) - - def test_e4m3_kernel_non_last_axis(self): - x = torch.tensor(TestScaledE4M3.x).cuda() - xq_ref = torch.tensor(TestScaledE4M3.xq_scaled).cuda() - amax = torch.ones(x.shape[0], 1).cuda() * x.abs().amax() - e4m3_x = tensor_quant.scaled_e4m3(x, amax, None, 4, 3) - assert torch.allclose(e4m3_x, xq_ref, atol=1e-4, rtol=1e-4) + quant_x = tensor_quant.scaled_e4m3(x, None, None, 4, 3) + assert torch.allclose(quant_x.cuda(), quant_x_ref) + + @pytest.mark.parametrize("axis", [0, 1, 2]) + def test_e4m3_per_channel(self, axis): + x = torch.randn(4, 4, 4, dtype=torch.float32).cuda() + amax = x.abs().amax(dim=[ax for ax in range(x.ndim) if ax != axis], keepdim=True) + xq_ref = tensor_quant.fp8_eager(x, amax) + xq_test = tensor_quant.scaled_e4m3(x, amax, None, 4, 3) + assert torch.allclose(xq_test, xq_ref) class Testfp4: diff --git a/tests/unit/torch/quantization/test_tensor_quant_cpu.py b/tests/unit/torch/quantization/test_tensor_quant_cpu.py index d64226add..5e457739c 100644 --- a/tests/unit/torch/quantization/test_tensor_quant_cpu.py +++ b/tests/unit/torch/quantization/test_tensor_quant_cpu.py @@ -19,29 +19,17 @@ import pytest import torch from _test_utils.torch_quantization.models import SimpleLinear -from _test_utils.torch_quantization.tensor_quant_common import ( - FakeAffineTensorQuantTester, - FakeTensorQuantTester, - TensorQuantTester, -) +from _test_utils.torch_quantization.tensor_quant_common import 
FakeTensorQuantTester import modelopt.torch.quantization as mtq from modelopt.torch.quantization.config import QuantizerAttributeConfig from modelopt.torch.quantization.nn import TensorQuantizer -class TestTensorQuantCPU(TensorQuantTester): - device = "cpu" - - class TestFakeTensorQuantCPU(FakeTensorQuantTester): device = "cpu" -class TestFakeAffineTensorQuantCPU(FakeAffineTensorQuantTester): - device = "cpu" - - class TestQuantizerAttributeConfig: def test_scaled_mode(self): num_bits = np.random.randint(1, 16)
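
A minimal eager-mode sketch of what the FP8 fake-quantization path in this diff computes: scale by 448 / amax into the E4M3 range, round-trip through `torch.float8_e4m3fn`, then unscale. It mirrors the new `fp8_eager` helper but is a standalone snippet; the `fp8_fake_quant_reference` name and tensor shapes are illustrative only, and it assumes a PyTorch build that ships `torch.float8_e4m3fn`.

```python
import torch


def fp8_fake_quant_reference(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    """Standalone mirror of the fp8_eager() helper added in this diff (illustrative only)."""
    dtype = x.dtype
    scale = 448.0 / amax.to(torch.float32)  # map |x| <= amax onto the E4M3 range [-448, 448]
    xq = (x.to(torch.float32) * scale).to(torch.float8_e4m3fn)  # quantize: round to E4M3
    return (xq.to(torch.float32) / scale).to(dtype)  # dequantize back to the input dtype


x = torch.randn(4, 16)

# Per-tensor scaling: a scalar amax (the case routed to cuda_ext_fp8.fake_e4m3fy on CUDA inputs).
y_tensor = fp8_fake_quant_reference(x, x.abs().amax())

# Per-channel scaling: a keepdim amax with one non-singleton dim
# (the case routed to cuda_ext_fp8.fake_e4m3fy_with_axis on CUDA inputs).
y_channel = fp8_fake_quant_reference(x, x.abs().amax(dim=0, keepdim=True))

print(y_tensor.shape, y_channel.shape)
```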
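A second small sketch (not part of the diff itself) of how the new `scaled_e4m3_impl` derives the per-channel axis purely from the shape of a keepdim `amax` before handing off to `fake_e4m3fy_with_axis`:

```python
import torch

x = torch.randn(2, 8, 4)

# amax reduced over every dim except dim 1, keepdim=True -> shape (1, 8, 1)
amax = x.abs().amax(dim=(0, 2), keepdim=True)

# The only non-singleton dim of amax is the quantization axis:
axis = amax.shape.index(amax.numel())
assert axis == 1

# On CUDA inputs this axis and the squeezed 1-D amax are what the new code passes to
# cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis); anything else
# (amax is None, CPU tensors, or an amax that is still >1-D after squeeze()) falls
# back to the eager fp8_eager() path.
print(axis)
```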