diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index 1a954279f..28bc34b2b 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -73,6 +73,8 @@ def clip_fn(self, x, min_val, max_val):
 
 
 class FloatQMixin(ABC):
+    export_fake_quantized = False
+
     @abstractmethod
     def quantize_fn(self, x, scale, zero_point, dtype, axis):
         pass
@@ -98,6 +100,8 @@ def signed_dtype(cls, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz):
 
 
 class QMixin(BitWidthHandlerMixin, ABC):
+    export_fake_quantized = False
+
     @classmethod
     @abstractmethod
     def uint8_dtype(cls):
@@ -228,6 +232,7 @@ def prepare_quantize_from_minifloat(self, module):
 
     def prepare_for_export(self, module):
         if module.is_quant_enabled:
+            assert not self.export_fake_quantized, "Activation quantization does not support fake quantization export"
             self.validate(module)
             if self._export_q_node:
                 self.prepare_quantize_from_floating_point(module)
@@ -262,12 +267,16 @@
                     quant_weight.mantissa_bit_width,
                     module.is_ocp,
                     module.is_fnuz)
+            if self.export_fake_quantized:
+                self.symbolic_kwargs['fake_quant_weights'] = quant_weight.value
         else:
             self.symbolic_kwargs = None
 
     def quantize_from_floating_point(self, x: Tensor):
         # Workaround for equal_cpu RuntimeError
         quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
+        if self.export_fake_quantized:
+            x = self.symbolic_kwargs['fake_quant_weights']
         # Before quantization, cast input to float32
         if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
             x = self.cast_fn(x, torch.float32)
@@ -337,6 +346,8 @@ def prepare_quantize_from_floating_point(self, module):
             scale = self.cast_fn(scale, torch.float32)
         self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
             scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
+        if self.export_fake_quantized:
+            self.symbolic_kwargs['fake_quant_weights'] = quant_weight.value
 
     def prepare_quantize_from_integer(self, module):
         int_weights = {
@@ -373,6 +384,8 @@ def prepare_for_export(self, module):
     def quantize_from_floating_point(self, x: Tensor):
         quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
         # Before quantization, cast input to float32
+        if self.export_fake_quantized:
+            x = self.symbolic_kwargs['fake_quant_weights']
         if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
             x = self.cast_fn(x, torch.float32)
         x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
@@ -450,6 +463,7 @@ def quantize_symbolic_kwargs(
 
     def prepare_for_export(self, module):
         if module.is_quant_enabled:
+            assert not self.export_fake_quantized, "Activation quantization does not support fake quantization export"
             self.validate(module)
             self.symbolic_kwargs['exponent_bit_width'] = module.exponent_bit_width()
             self.symbolic_kwargs['mantissa_bit_width'] = module.mantissa_bit_width()
@@ -529,6 +543,7 @@ class QCDQCastActQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin, ABC):
 
     handled_layer = ActQuantProxyFromInjector
 
     def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
+        # compute axis before redefining scale
         axis = cls.quant_axis(scale)
         scale = to_0dim_if_scalar(scale.flatten())
@@ -544,6 +559,7 @@ def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
 
     def prepare_for_export(self, module):
         if module.is_quant_enabled:
+            assert not self.export_fake_quantized, "Activation quantization does not support fake quantization export"
             self.validate(module)
             self.symbolic_kwargs['bit_width'] = module.bit_width()
@@ -779,6 +795,7 @@ def validate(self, module):
 
     def prepare_for_export(self, module: TruncQuantProxyFromInjector):
         if module.is_quant_enabled:
+            assert not self.export_fake_quantized, "Activation quantization does not support fake quantization export"
             self.validate(module)
             self.symbolic_kwargs = {
                 'narrow_range': module.is_narrow_range,
diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py
index e13ac73d1..4ec8a439c 100644
--- a/src/brevitas/export/onnx/standard/qcdq/handler.py
+++ b/src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from abc import ABC
+import warnings
 from warnings import warn
 
 import torch
@@ -74,7 +75,10 @@ class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC):
 
     def validate(self, module):
         if getattr(self, '_export_q_node', True):
-            assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
+            if module.rounding_mode.upper() != 'ROUND':
+                warnings.warn("Exporting a rounding mode other than ROUND requires exporting"
+                              " and storing fake-quantized weights. This could cause OOM issues.")
+                self.export_fake_quantized = True
         super().validate(module)
 
     def quantize_fn(self, x, scale, zero_point, dtype, axis):
@@ -100,7 +104,8 @@ def validate(self, module):
         # ONNX QuantizeLinear supports only 8b output with round to nearest even.
         # Below 8b quantization is supported through clipping.
         if getattr(self, '_export_q_node', True):
-            assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
+            if module.rounding_mode.upper() != 'ROUND':
+                self.export_fake_quantized = True
         assert not module.is_groupwise, "Export with Per Group quantization not supported"
         self.validate_8b_bit_width(module.bit_width(), le_then=True)
 
diff --git a/tests/brevitas_ort/quant_module_cases.py b/tests/brevitas_ort/quant_module_cases.py
index 4cf6944c0..0cf6e370e 100644
--- a/tests/brevitas_ort/quant_module_cases.py
+++ b/tests/brevitas_ort/quant_module_cases.py
@@ -15,13 +15,22 @@
 class QuantWBIOLCases:
 
+    @parametrize(
+        'rounding_type', ['round', 'floor'], ids=[f'rtype_{r}' for r in ['round', 'floor']])
     @parametrize('impl', QUANT_WBIOL_IMPL, ids=[f'{c.__name__}' for c in QUANT_WBIOL_IMPL])
     @parametrize('input_bit_width', BIT_WIDTHS, ids=[f'i{b}' for b in BIT_WIDTHS])
     @parametrize('weight_bit_width', BIT_WIDTHS, ids=[f'w{b}' for b in BIT_WIDTHS])
     @parametrize('output_bit_width', BIT_WIDTHS, ids=[f'o{b}' for b in BIT_WIDTHS])
     @parametrize('quantizers', WBIOL_QUANTIZERS.values(), ids=list(WBIOL_QUANTIZERS.keys()))
     def case_quant_wbiol(
-            self, impl, input_bit_width, weight_bit_width, output_bit_width, quantizers, request):
+            self,
+            rounding_type,
+            impl,
+            input_bit_width,
+            weight_bit_width,
+            output_bit_width,
+            quantizers,
+            request):
 
         # Change the case_id based on current value of Parameters
         set_case_id(request.node.callspec.id, QuantWBIOLCases.case_quant_wbiol)
 
@@ -29,9 +38,9 @@ def case_quant_wbiol(
         weight_quant, io_quant = quantizers
         is_fp8 = weight_quant == Fp8e4m3OCPWeightPerTensorFloat
         is_dynamic = io_quant == ShiftedUint8DynamicActPerTensorFloat
-        if is_fp8:
+        if is_fp8 or rounding_type == 'floor':
             if weight_bit_width < 8 or input_bit_width < 8 or output_bit_width < 8:
-                pytest.skip('FP8 export requires total bitwidth equal to 8')
+                pytest.skip('FP8 export and FLOOR rounding require all bitwidths equal to 8')
             torch.use_deterministic_algorithms(False)
         else:
             torch.use_deterministic_algorithms(True)
@@ -60,6 +69,7 @@ def __init__(self):
                     input_bit_width=input_bit_width,
                     output_bit_width=output_bit_width,
                     bias_quant=bias_quantizer,
+                    weight_float_to_int_impl_type=rounding_type,
                     return_quant_tensor=return_quant_tensor)
 
                 self.conv.weight.data.uniform_(-0.01, 0.01)
diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py
index 693b9274d..9202e8a8d 100644
--- a/tests/brevitas_ort/test_quant_module.py
+++ b/tests/brevitas_ort/test_quant_module.py
@@ -3,6 +3,7 @@
 
 from functools import reduce
 from operator import mul
+import os
 
 from packaging.version import parse
 import pytest
@@ -13,6 +14,7 @@
 from brevitas import torch_version
 from tests.marker import requires_pt_ge
 
+from ..export_fixture import rm_onnx
 from .common import *
 from .quant_module_cases import QuantAvgPoolCases
 from .quant_module_cases import QuantRecurrentCases
@@ -25,13 +27,18 @@ def test_ort_wbiol(model, export_type, current_cases):
     cases_generator_func = current_cases['model'][1]
     case_id = get_case_id(cases_generator_func)
+    rounding = case_id.split('-')[0]
     impl = case_id.split('-')[
-        -2]  # Inverse list of definition, 'export_type' is -1, 'impl' is -2, etc.
-    quantizer = case_id.split('-')[-6]
-    o_bit_width = case_id.split('-')[-5]
-    i_bit_width = case_id.split('-')[-3]
+        -3]  # Inverse list of definition, 'export_type' is -1, 'impl' is -3, etc.
+    quantizer = case_id.split('-')[-7]
+    o_bit_width = case_id.split('-')[-6]
+    i_bit_width = case_id.split('-')[-4]
     onnx_opset = 14
     export_q_weight = False
+
+    if rounding == 'round':
+        export_q_weight = True
+
     if 'per_channel' in quantizer and 'asymmetric' in quantizer:
         pytest.skip('Per-channel zero-point is not well supported in ORT.')
     if 'QuantLinear' in impl and 'asymmetric' in quantizer:
@@ -56,7 +63,7 @@ def test_ort_wbiol(model, export_type, current_cases):
     elif impl in ('QuantConv3d', 'QuantConvTranspose3d'):
         in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES)
     else:
-        raise RuntimeError("Unsupported operation")
+        raise RuntimeError(f"Unsupported operation {impl}")
 
     inp = gen_linspaced_data(reduce(mul, in_size), -1, 1).reshape(in_size)
 
@@ -73,6 +80,8 @@
         onnx_opset=onnx_opset,
         export_q_weight=export_q_weight)
+
+    rm_onnx(export_name)
 
 
 @parametrize_with_cases('model', cases=QuantAvgPoolCases)
 @requires_pt_ge('1.10')
@@ -84,6 +93,7 @@ def test_ort_avgpool(model, current_cases):
     export_name = 'qcdq_quant_avgpool.onnx'
     assert is_brevitas_ort_close(
         model, inp, export_name, 'qcdq', tolerance=INT_TOLERANCE, first_output_only=True)
+    rm_onnx(export_name)
 
 
 @parametrize_with_cases('model', cases=QuantRecurrentCases)
@@ -105,3 +115,4 @@ def test_ort_lstm(model, export_type, current_cases):
     model.eval()
     export_name = f'lstm_export_{case_id}.onnx'
     assert is_brevitas_ort_close(model, inp, export_name, export_type, tolerance=FLOAT_TOLERANCE)
+    rm_onnx(export_name)
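
Reviewer note on why the fallback is sound: ONNX `QuantizeLinear` hard-codes round-to-nearest-even, so a quantizer configured with FLOOR rounding cannot be expressed through a Q node directly. Caching `fake_quant_weights` works because, once Brevitas has applied its own rounding, the dequantized weights divide exactly by the scale, so the Q node's rounding becomes a no-op on them. A minimal sketch of that argument (illustrative only, not Brevitas API; symmetric int8 assumed):

```python
import torch

# Brevitas-side fake quantization with FLOOR rounding (simplified, symmetric int8).
def fake_quantize_floor(w, scale):
    q = torch.clamp(torch.floor(w / scale), -128, 127)
    return q * scale  # dequantized ("fake quantized") weights

w = torch.randn(8, 8)
scale = torch.tensor(0.05)
w_fq = fake_quantize_floor(w, scale)

# Exporting `w` through QuantizeLinear would round to nearest even and diverge
# from w_fq. Exporting the cached w_fq instead is exact: w_fq / scale is already
# (numerically) integral, so QuantizeLinear's rounding is a no-op on it.
assert torch.equal(torch.round(w_fq / scale) * scale, w_fq)
```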
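End to end, the path exercised by the new `rtype_floor` test cases looks roughly like this. A sketch, assuming the usual Brevitas layer kwargs and the standard QCDQ export entry point; `conv_floor.onnx` is just an example path:

```python
import torch
from brevitas.export import export_onnx_qcdq
from brevitas.nn import QuantConv2d

# FLOOR rounding on the weight quantizer; 8b weights, matching the bit-width
# restriction enforced by the skip in quant_module_cases.py above.
model = QuantConv2d(
    3, 16, kernel_size=3,
    weight_bit_width=8,
    weight_float_to_int_impl_type='floor')
model.eval()

# With this patch, validate() warns instead of asserting, flips
# export_fake_quantized, and quantize_from_floating_point feeds the cached
# fake-quantized weights into QuantizeLinear instead of the raw fp weights.
export_onnx_qcdq(model, torch.randn(1, 3, 32, 32), export_path='conv_floor.onnx')
```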
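The tests pull `rm_onnx` from `tests/brevitas_ort/export_fixture`, which is not part of this diff. Presumably it just deletes the exported model so the heavily parametrized runs don't accumulate ONNX files on disk; a hypothetical stand-in consistent with how it is called:

```python
import os

def rm_onnx(export_name: str) -> None:
    # Remove the exported ONNX file; tolerate a missing file so tests that
    # skipped or failed before exporting don't raise here.
    if os.path.exists(export_name):
        os.remove(export_name)
```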