17 changes: 17 additions & 0 deletions src/brevitas/export/common/handler/qcdq.py
@@ -73,6 +73,8 @@ def clip_fn(self, x, min_val, max_val):

class FloatQMixin(ABC):

export_fake_quantized = False

@abstractmethod
def quantize_fn(self, x, scale, zero_point, dtype, axis):
pass
@@ -98,6 +100,8 @@ def signed_dtype(cls, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz):

class QMixin(BitWidthHandlerMixin, ABC):

export_fake_quantized = False

@classmethod
@abstractmethod
def uint8_dtype(cls):
@@ -228,6 +232,7 @@ def prepare_quantize_from_minifloat(self, module):

def prepare_for_export(self, module):
if module.is_quant_enabled:
assert not self.export_fake_quantized, "Activation quantization does not support fake quantization export"
self.validate(module)
if self._export_q_node:
self.prepare_quantize_from_floating_point(module)
@@ -262,12 +267,16 @@ def prepare_for_export(self, module):
quant_weight.mantissa_bit_width,
module.is_ocp,
module.is_fnuz)
if self.export_fake_quantized:
self.symbolic_kwargs['fake_quant_weights'] = quant_weight.value
else:
self.symbolic_kwargs = None

def quantize_from_floating_point(self, x: Tensor):
# Workaround for equal_cpu RuntimeError
quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
if self.export_fake_quantized:
x = self.symbolic_kwargs['fake_quant_weights']
# Before quantization, cast input to float32
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, torch.float32)
@@ -337,6 +346,8 @@ def prepare_quantize_from_floating_point(self, module):
scale = self.cast_fn(scale, torch.float32)
self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
if self.export_fake_quantized:
self.symbolic_kwargs['fake_quant_weights'] = quant_weight.value

def prepare_quantize_from_integer(self, module):
int_weights = {
@@ -373,6 +384,8 @@ def prepare_for_export(self, module):
def quantize_from_floating_point(self, x: Tensor):
quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
# Before quantization, cast input to float32
if self.export_fake_quantized:
x = self.symbolic_kwargs['fake_quant_weights']
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, torch.float32)
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
@@ -450,6 +463,7 @@ def quantize_symbolic_kwargs(

def prepare_for_export(self, module):
if module.is_quant_enabled:
assert not self.export_fake_quantized, "Activation quantization does not support fake quantization export"
self.validate(module)
self.symbolic_kwargs['exponent_bit_width'] = module.exponent_bit_width()
self.symbolic_kwargs['mantissa_bit_width'] = module.mantissa_bit_width()
@@ -529,6 +543,7 @@ class QCDQCastActQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin, ABC):
handled_layer = ActQuantProxyFromInjector

def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):

# compute axis before redefining scale
axis = cls.quant_axis(scale)
scale = to_0dim_if_scalar(scale.flatten())
@@ -544,6 +559,7 @@ def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):

def prepare_for_export(self, module):
if module.is_quant_enabled:
assert not self.export_fake_quantized, "Activation quantization does not support fake quantization export"
self.validate(module)
self.symbolic_kwargs['bit_width'] = module.bit_width()

@@ -779,6 +795,7 @@ def validate(self, module):

def prepare_for_export(self, module: TruncQuantProxyFromInjector):
if module.is_quant_enabled:
assert not self.export_fake_quantized, "Activation quantization does not support fake quantization export"
self.validate(module)
self.symbolic_kwargs = {
'narrow_range': module.is_narrow_range,
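
The hunks above all follow the same pattern: a class-level `export_fake_quantized` flag that validation can flip, cached `fake_quant_weights` in `symbolic_kwargs` at prepare time, and a substitution of those cached weights for the float input at quantize time. The sketch below is a hypothetical, self-contained class, not the actual Brevitas mixin (the real handlers derive scale and zero-point from the proxy and keep them in `symbolic_kwargs`); it only illustrates that control flow.

```python
import warnings

import torch


class FakeQuantFallbackHandler:
    # Class-level default, mirroring the new attribute on QMixin / FloatQMixin
    export_fake_quantized = False

    def __init__(self, scale: float = 0.01):
        self.scale = scale
        self.symbolic_kwargs = {}

    def validate(self, rounding_mode: str) -> None:
        # ROUND maps directly onto ONNX QuantizeLinear; any other rounding mode
        # forces the fallback, which keeps a full-size float copy of the weights
        # alive until export (the source of the OOM warning).
        if rounding_mode.upper() != 'ROUND':
            warnings.warn('Falling back to storing fake-quantized weights.')
            self.export_fake_quantized = True

    def prepare_for_export(self, fake_quant_weight: torch.Tensor) -> None:
        if self.export_fake_quantized:
            # Cache the already fake-quantized values computed by the quant proxy
            self.symbolic_kwargs['fake_quant_weights'] = fake_quant_weight

    def quantize_from_floating_point(self, x: torch.Tensor) -> torch.Tensor:
        if self.export_fake_quantized:
            # Ignore the incoming float weights: quantizing the cached
            # fake-quantized copy with ROUND reproduces the original rounding.
            x = self.symbolic_kwargs['fake_quant_weights']
        return torch.clamp(torch.round(x / self.scale), -128, 127).to(torch.int8)


handler = FakeQuantFallbackHandler()
handler.validate('floor')                      # warns and enables the fallback
handler.prepare_for_export(torch.randn(4, 4))  # stand-in for quant_weight.value
q = handler.quantize_from_floating_point(torch.randn(4, 4))
```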
9 changes: 7 additions & 2 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
import warnings
from warnings import warn

import torch
@@ -74,7 +75,10 @@ class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC):

def validate(self, module):
if getattr(self, '_export_q_node', True):
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
if module.rounding_mode.upper() != 'ROUND':
warnings.warn("Exporting different rounding function than ROUND requires exporting" \
" and storing fake quantized weights. This could cause OOM issues.")
self.export_fake_quantized = True
super().validate(module)

def quantize_fn(self, x, scale, zero_point, dtype, axis):
@@ -100,7 +104,8 @@ def validate(self, module):
# ONNX QuantizeLinear supports only 8b output with round to nearest even.
# Below 8b quantization is supported through clipping.
if getattr(self, '_export_q_node', True):
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
if module.rounding_mode.upper() != 'ROUND':
self.export_fake_quantized = True
assert not module.is_groupwise, "Export with Per Group quantization not supported"

self.validate_8b_bit_width(module.bit_width(), le_then=True)
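
From the user side, the relaxed `validate()` means a FLOOR-rounded weight quantizer can now be exported instead of tripping the old assertion. The snippet below is a rough sketch, not taken from the PR: layer shape, bit width, and file name are illustrative, and it assumes the standard `export_onnx_qcdq` entry point with its `(module, args, export_path)` signature.

```python
import warnings

import torch
from brevitas.export import export_onnx_qcdq
from brevitas.nn import QuantConv2d

# Weight-only quantized conv whose weight quantizer uses FLOOR rounding, the
# same kwarg the new test cases pass via rounding_type.
model = QuantConv2d(
    in_channels=3,
    out_channels=8,
    kernel_size=3,
    weight_bit_width=8,
    weight_float_to_int_impl_type='floor',
    return_quant_tensor=False)
model.eval()

inp = torch.randn(1, 3, 16, 16)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # Previously this raised an AssertionError; now the handler warns and
    # exports by quantizing the stored fake-quantized weights instead.
    export_onnx_qcdq(model, args=inp, export_path='floor_rounding_sketch.onnx')
print([str(w.message) for w in caught])
```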
16 changes: 13 additions & 3 deletions tests/brevitas_ort/quant_module_cases.py
@@ -15,23 +15,32 @@

class QuantWBIOLCases:

@parametrize(
'rounding_type', ['round', 'floor'], ids=[f'rtype_{r}' for r in ['round', 'floor']])
@parametrize('impl', QUANT_WBIOL_IMPL, ids=[f'{c.__name__}' for c in QUANT_WBIOL_IMPL])
@parametrize('input_bit_width', BIT_WIDTHS, ids=[f'i{b}' for b in BIT_WIDTHS])
@parametrize('weight_bit_width', BIT_WIDTHS, ids=[f'w{b}' for b in BIT_WIDTHS])
@parametrize('output_bit_width', BIT_WIDTHS, ids=[f'o{b}' for b in BIT_WIDTHS])
@parametrize('quantizers', WBIOL_QUANTIZERS.values(), ids=list(WBIOL_QUANTIZERS.keys()))
def case_quant_wbiol(
self, impl, input_bit_width, weight_bit_width, output_bit_width, quantizers, request):
self,
rounding_type,
impl,
input_bit_width,
weight_bit_width,
output_bit_width,
quantizers,
request):

# Change the case_id based on the current value of the parameters
set_case_id(request.node.callspec.id, QuantWBIOLCases.case_quant_wbiol)

weight_quant, io_quant = quantizers
is_fp8 = weight_quant == Fp8e4m3OCPWeightPerTensorFloat
is_dynamic = io_quant == ShiftedUint8DynamicActPerTensorFloat
if is_fp8:
if is_fp8 or rounding_type == 'floor':
if weight_bit_width < 8 or input_bit_width < 8 or output_bit_width < 8:
pytest.skip('FP8 export requires total bitwidth equal to 8')
pytest.skip('FP8 export and FLOOR rounding require all bitwidths equal to 8')
torch.use_deterministic_algorithms(False)
else:
torch.use_deterministic_algorithms(True)
@@ -60,6 +69,7 @@ def __init__(self):
input_bit_width=input_bit_width,
output_bit_width=output_bit_width,
bias_quant=bias_quantizer,
weight_float_to_int_impl_type=rounding_type,
return_quant_tensor=return_quant_tensor)
self.conv.weight.data.uniform_(-0.01, 0.01)

21 changes: 16 additions & 5 deletions tests/brevitas_ort/test_quant_module.py
@@ -3,6 +3,7 @@

from functools import reduce
from operator import mul
import os

from packaging.version import parse
import pytest
@@ -13,6 +14,7 @@
from brevitas import torch_version
from tests.marker import requires_pt_ge

from ..export_fixture import rm_onnx
from .common import *
from .quant_module_cases import QuantAvgPoolCases
from .quant_module_cases import QuantRecurrentCases
@@ -25,13 +27,18 @@
def test_ort_wbiol(model, export_type, current_cases):
cases_generator_func = current_cases['model'][1]
case_id = get_case_id(cases_generator_func)
rounding = case_id.split('-')[0]
impl = case_id.split('-')[
-2] # Inverse list of definition, 'export_type' is -1, 'impl' is -2, etc.
quantizer = case_id.split('-')[-6]
o_bit_width = case_id.split('-')[-5]
i_bit_width = case_id.split('-')[-3]
-3] # Indices count from the end of the id: 'export_type' is -1, 'impl' is -3, etc.
quantizer = case_id.split('-')[-7]
o_bit_width = case_id.split('-')[-6]
i_bit_width = case_id.split('-')[-4]
onnx_opset = 14
export_q_weight = False

if rounding == 'round':
export_q_weight = True

if 'per_channel' in quantizer and 'asymmetric' in quantizer:
pytest.skip('Per-channel zero-point is not well supported in ORT.')
if 'QuantLinear' in impl and 'asymmetric' in quantizer:
@@ -56,7 +63,7 @@ def test_ort_wbiol(model, export_type, current_cases):
elif impl in ('QuantConv3d', 'QuantConvTranspose3d'):
in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES)
else:
raise RuntimeError("Unsupported operation")
raise RuntimeError(f"Unsupported operation {impl}")

inp = gen_linspaced_data(reduce(mul, in_size), -1, 1).reshape(in_size)

@@ -73,6 +80,8 @@ def test_ort_wbiol(model, export_type, current_cases):
onnx_opset=onnx_opset,
export_q_weight=export_q_weight)

rm_onnx(export_name)


@parametrize_with_cases('model', cases=QuantAvgPoolCases)
@requires_pt_ge('1.10')
@@ -84,6 +93,7 @@ def test_ort_avgpool(model, current_cases):
export_name = 'qcdq_quant_avgpool.onnx'
assert is_brevitas_ort_close(
model, inp, export_name, 'qcdq', tolerance=INT_TOLERANCE, first_output_only=True)
rm_onnx(export_name)


@parametrize_with_cases('model', cases=QuantRecurrentCases)
@@ -105,3 +115,4 @@ def test_ort_lstm(model, export_type, current_cases):
model.eval()
export_name = f'lstm_export_{case_id}.onnx'
assert is_brevitas_ort_close(model, inp, export_name, export_type, tolerance=FLOAT_TOLERANCE)
rm_onnx(export_name)
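
The `rm_onnx` helper called above is imported from the `export_fixture` module, which is outside this diff. A minimal equivalent, given here only as an assumption about what the fixture does (it is not the actual implementation), would simply delete the exported artifact after each test:

```python
import os


def rm_onnx(export_name: str) -> None:
    # Delete the ONNX file written by the test, tolerating a missing file
    # (e.g. if the export was skipped before writing anything).
    if os.path.exists(export_name):
        os.remove(export_name)
```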