
Commit bd789d1

update 2.5.1
1 parent 09ff5cb commit bd789d1

12 files changed: +366 −169 lines

mqbench/convert_deploy.py

Lines changed: 2 additions & 2 deletions
@@ -68,7 +68,7 @@ def convert_onnx(model: GraphModule, input_shape_dict, dummy_input, onnx_model_p
     input_names = list(dummy_input.keys())
     dummy_input = tuple(dummy_input.values())
     # Per-channel QuantizeLinear and DequantizeLinear is supported since opset 13
-    opset_version = 13 if kwargs.get('deploy_to_qlinear', False) else 11
+    opset_version = 13 if kwargs.get('deploy_to_qlinear', False) else 13
     with torch.no_grad():
         try:
             from torch.onnx.utils import ONNXCheckerError
@@ -159,7 +159,7 @@ def deploy_qparams_stpu(model: GraphModule, onnx_model_path, model_name, **kwarg
     remove_fakequantize_and_collect_params_stpu(onnx_model_path, model_name)


-def convert_deploy(model: GraphModule, backend_type: BackendType,
+def convert_deploy(model: GraphModule, backend_type: BackendType,
                    input_shape_dict=None, dummy_input=None, output_path='./',
                    model_name='mqbench_qmodel', deploy_to_qlinear=False, **extra_kwargs):
     r"""Convert model to onnx model and quantization params depends on backend.

mqbench/custom_quantizer/academic_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import torch
 from torch.fx import GraphModule
 from torch.quantization import propagate_qconfig_
-from torch.quantization.fx.qconfig_utils import get_flattened_qconfig_dict
+from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict

 from mqbench.utils import is_symmetric_quant, getitem2node
 from mqbench.utils.logger import logger

mqbench/custom_quantizer/model_quantizer.py

Lines changed: 21 additions & 16 deletions
@@ -23,9 +23,7 @@
 from torch.quantization.utils import (
     get_combined_dict
 )
-from torch.quantization.fx.qconfig_utils import (
-    get_flattened_qconfig_dict
-)
+from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict
 from torch.quantization.quantize_fx import (
     _fuse_fx
 )
@@ -34,8 +32,12 @@
 from mqbench.utils.logger import logger
 from mqbench.utils.registry import register_model_quantizer
 from mqbench.prepare_by_platform import BackendType
-
-
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+)
+from torch.ao.quantization.backend_config.utils import (
+    get_module_to_qat_module,
+)
 @register_model_quantizer(BackendType.Tensorrt)
 @register_model_quantizer(BackendType.NNIE)
 class ModelQuantizer(object):
@@ -60,9 +62,9 @@ def __init__(self, extra_quantizer_dict, extra_fuse_dict):
         self.exclude_node_name = extra_quantizer_dict.get('exclude_node_name', [])
         self.extra_fuse_dict = extra_fuse_dict

-    def prepare(self, model: GraphModule, qconfig):
-        model = _fuse_fx(model, self.extra_fuse_dict)
-        model = self._weight_quant(model, qconfig)
+    def prepare(self, model: GraphModule, qconfig, is_qat, backend_config):
+        model = _fuse_fx(model, is_qat, self.extra_fuse_dict, backend_config)
+        model = self._weight_quant(model, qconfig, backend_config)
         model = self._insert_fake_quantize_for_act_quant(model, qconfig)
         return model

@@ -119,11 +121,11 @@ def _fix_succ_recursivly(self, args, target_node, inserted_node):
         else:
             raise NotImplementedError('{} can not be handled now.'.format(type(args)))

-    def _weight_quant(self, model: GraphModule, qconfig):
+    def _weight_quant(self, model: GraphModule, qconfig, backend_config):
         logger.info("Replace module to qat module.")
         flattened_qconfig_dict = get_flattened_qconfig_dict({'': qconfig})
         propagate_qconfig_(model, flattened_qconfig_dict)
-        self._qat_swap_modules(model, self.additional_qat_module_mapping)
+        self._qat_swap_modules(model, self.additional_qat_module_mapping, backend_config)
         return model

     @property
@@ -245,15 +247,18 @@ def _find_act_quants(self, model: GraphModule) -> List:
                 node_need_to_quantize_output.append(_node)
         return node_need_to_quantize_output

-    def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable]):
+    def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable], backend_config: BackendConfig):
+        # all_mappings = get_combined_dict(
+        #     get_default_qat_module_mappings(), additional_qat_module_mapping)
         all_mappings = get_combined_dict(
-            get_default_qat_module_mappings(), additional_qat_module_mapping)
-        root = self._convert(root, all_mappings, inplace=True)
+            get_module_to_qat_module(backend_config), additional_qat_module_mapping)
+        root = self._convert(root, all_mappings, inplace=True, backend_config = backend_config)
        return root

-    def _convert(self, module, mapping=None, inplace=False, scope=''):
+    def _convert(self, module, mapping=None, inplace=False, backend_config=None, scope=''):
         if mapping is None:
-            mapping = get_default_static_quant_module_mappings()
+            # mapping = get_default_static_quant_module_mappings()
+            mapping = get_module_to_qat_module(backend_config)

         if not inplace:
             module = copy.deepcopy(module)
@@ -266,7 +271,7 @@ def _convert(self, module, mapping=None, inplace=False, scope=''):
                 continue
             if not isinstance(mod, _FusedModule):
                 self._convert(mod, mapping, True, new_scope)
-            reassign[name] = swap_module(mod, mapping, {})
+            reassign[name] = swap_module(mod, mapping, {}, False)
         for key, value in reassign.items():
             module._modules[key] = value
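
For context, a small sketch (not from the commit) of the mapping that `_qat_swap_modules` now derives from a BackendConfig instead of the legacy `get_default_qat_module_mappings()`. `get_native_backend_config()` is assumed here as a representative BackendConfig, and the extra-mapping entry is invented for illustration.

import torch.nn as nn
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.backend_config.utils import get_module_to_qat_module
from torch.quantization.utils import get_combined_dict

backend_config = get_native_backend_config()            # assumed stand-in for the backend's config
module_to_qat = get_module_to_qat_module(backend_config)
print(module_to_qat.get(nn.Conv2d))                     # e.g. a QAT Conv2d class

# User-supplied additional mappings still take precedence, mirroring _qat_swap_modules above.
all_mappings = get_combined_dict(module_to_qat, {nn.Identity: nn.Identity})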

mqbench/custom_quantizer/openvino_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import torch
 from torch.fx import GraphModule
 from torch.quantization import propagate_qconfig_
-from torch.quantization.fx.qconfig_utils import get_flattened_qconfig_dict
+from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict
 from torch.quantization.quantize_fx import _fuse_fx

 from mqbench.utils import is_symmetric_quant

mqbench/custom_symbolic_opset.py

Lines changed: 135 additions & 15 deletions
@@ -1,23 +1,143 @@
 from torch.onnx import register_custom_op_symbolic

 # Register symbolic op for torch.quantize_function op.
+import functools
+from torch.onnx._internal import jit_utils, registration
+_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)
+from torch.onnx import (
+    _type_utils,
+    symbolic_helper,
+    symbolic_opset9 as opset9,
+)
+import torch._C._onnx as _C_onnx
+import torch
+@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
+@symbolic_helper.parse_args("v", "v", "v", "i", "i")
+def fake_quantize_per_tensor_affine(
+    g: jit_utils.GraphContext,
+    inputs,
+    scale,
+    zero_point,
+    quant_min=-128,
+    quant_max=127,
+):
+    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+    #     raise errors.SymbolicValueError(
+    #         "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
+    #         f"Got ({quant_min}, {quant_max})",
+    #         inputs,
+    #     )
+    if quant_min == 0:
+        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
+    else:
+        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
+    if (
+        _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
+        != _type_utils.JitScalarType.FLOAT
+    ):
+        scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
+    if (quant_min, quant_max) == (0, 127):
+        quantized = g.op(
+            "Clip",
+            quantized,
+            opset9.unused(g),
+            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
+        )
+    return g.op("DequantizeLinear", quantized, scale, zero_point)

-def _fake_quantize_learnable_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max, grad_factor):
-    return g.op("::LearnablePerTensorAffine", x, scale, zero_point, quant_min, quant_max)
+@_onnx_symbolic("aten::fake_quantize_per_channel_affine")
+@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
+def fake_quantize_per_channel_affine(
+    g: jit_utils.GraphContext,
+    inputs,
+    scale,
+    zero_point,
+    axis,
+    quant_min=-128,
+    quant_max=127,
+):
+    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+    #     raise errors.SymbolicValueError(
+    #         "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
+    #         f"Got ({quant_min}, {quant_max})",
+    #         inputs,
+    #     )
+    # ONNX defines zero_point to be int8 or uint8
+    if quant_min == 0:
+        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
+    else:
+        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
+    quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
+    if (quant_min, quant_max) == (0, 127):
+        quantized = g.op(
+            "Clip",
+            quantized,
+            opset9.unused(g),
+            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
+        )
+    return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)

+@_onnx_symbolic("aten::_fake_quantize_learnable_per_tensor_affine")
+@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
+def _fake_quantize_learnable_per_tensor_affine(
+    g: jit_utils.GraphContext,
+    inputs,
+    scale,
+    zero_point,
+    quant_min=-128,
+    quant_max=127,
+    grad_factor=0,
+):
+    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+    #     raise errors.SymbolicValueError(
+    #         "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
+    #         f"Got ({quant_min}, {quant_max})",
+    #         inputs,
+    #     )
+    if quant_min == 0:
+        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
+    else:
+        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
+    if (
+        _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
+        != _type_utils.JitScalarType.FLOAT
+    ):
+        scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
+    if (quant_min, quant_max) == (0, 127):
+        quantized = g.op(
+            "Clip",
+            quantized,
+            opset9.unused(g),
+            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
+        )
+    return g.op("DequantizeLinear", quantized, scale, zero_point)

-register_custom_op_symbolic('::_fake_quantize_learnable_per_tensor_affine', _fake_quantize_learnable_per_tensor_affine, 11)

+# def _fake_quantize_learnable_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max, grad_factor):
+#     return g.op(x, scale, zero_point, quant_min, quant_max)
+#
+#
+# register_custom_op_symbolic('::_fake_quantize_learnable_per_tensor_affine', _fake_quantize_learnable_per_tensor_affine, 11)
+#
+#
+# def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, quant_min, quant_max):
+#     return g.op("::FixedPerChannelAffine", x, scale, zero_point, ch_axis, quant_min, quant_max)
+#
+#
+# register_custom_op_symbolic('::fake_quantize_per_channel_affine', fake_quantize_per_channel_affine, 11)
+#
+#
+# def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max):
+#     return g.op("::FixedPerTensorAffine", x, scale, zero_point, quant_min, quant_max)
+#
+#
+# register_custom_op_symbolic('::fake_quantize_per_tensor_affine', fake_quantize_per_tensor_affine, 11)

-def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, quant_min, quant_max):
-    return g.op("::FixedPerChannelAffine", x, scale, zero_point, ch_axis, quant_min, quant_max)
-
-
-register_custom_op_symbolic('::fake_quantize_per_channel_affine', fake_quantize_per_channel_affine, 11)
-
-
-def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max):
-    return g.op("::FixedPerTensorAffine", x, scale, zero_point, quant_min, quant_max)
-
-
-register_custom_op_symbolic('::fake_quantize_per_tensor_affine', fake_quantize_per_tensor_affine, 11)
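
A rough sanity-check sketch, not part of the commit: it assumes that importing mqbench.custom_symbolic_opset is enough to register the opset-13 overrides above, and that the exporter then lowers aten::fake_quantize_per_tensor_affine to a QuantizeLinear/DequantizeLinear pair.

import io
import torch
import onnx
import mqbench.custom_symbolic_opset  # noqa: F401  (assumed to register the symbolics above)

class FakeQuantModel(torch.nn.Module):
    def forward(self, x):
        # scale / zero_point / qmin / qmax are arbitrary illustration values
        return torch.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127)

buf = io.BytesIO()
torch.onnx.export(FakeQuantModel(), torch.randn(1, 3, 8, 8), buf, opset_version=13)
graph = onnx.load_model_from_string(buf.getvalue()).graph

# QuantizeLinear and DequantizeLinear should show up (possibly next to Constant nodes).
print([node.op_type for node in graph.node])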

mqbench/fake_quantize/fixed.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from mqbench.utils.hook import PerChannelLoadHook


-_version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+_version_under_1100 = int(torch.__version__.split('.')[0]) == 1 and int(torch.__version__.split('.')[1]) < 10

 class FixedFakeQuantize(QuantizeBase):
     """This is actually torch.quantization.FakeQuantize.

mqbench/fake_quantize/lsq.py

Lines changed: 20 additions & 3 deletions
@@ -4,7 +4,12 @@
 from mqbench.fake_quantize.quantize_base import QuantizeBase
 from mqbench.utils import is_symmetric_quant, is_tracing_state
 from mqbench.utils.hook import PerChannelLoadHook
-
+from torch.onnx import (
+    _type_utils,
+    symbolic_helper,
+    symbolic_opset9 as opset9,
+)
+import torch._C._onnx as _C_onnx

 class LearnableFakeQuantize(QuantizeBase):
     r""" This is an extension of the FakeQuantize module in fake_quantize.py, which
@@ -106,5 +111,17 @@ def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_facto
                                                                     quant_min, quant_max, grad_factor)

     @staticmethod
-    def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor):
-        return g.op("::FakeQuantizeLearnablePerchannelAffine", x, scale, zero_point, quant_min_i=quant_min, quant_max_i=quant_max)
+    def symbolic(g, inputs, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor):
+        if quant_min == 0:
+            zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
+        else:
+            zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
+        quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=ch_axis)
+        if (quant_min, quant_max) == (0, 127):
+            quantized = g.op(
+                "Clip",
+                quantized,
+                opset9.unused(g),
+                g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
+            )
+        return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=ch_axis)
