Commit ea2334b

Add static quantization runner
- Add a general command-line tool for static quantization
- Support loading TensorQuantOverride from json file
- Add the corresponding README
1 parent 9dcb99c commit ea2334b

2 files changed: +342 -1 lines changed

Lines changed: 83 additions & 1 deletion
# Quantization Tool

This tool can be used to quantize select ONNX models. Support is based on the operators in the model. Please refer to https://onnxruntime.ai/docs/performance/quantization.html for usage details and https://github.com/microsoft/onnxruntime-inference-examples/tree/main/quantization for examples.

## Static Quantization Tool

### Build

Please add `--enable_pybind` and `--build_wheel` to the build command to obtain the Python tools.

```bash
cd onnxruntime
.\build.bat --config RelWithDebInfo --build_shared_lib --parallel --cmake_generator "Visual Studio 17 2022" --enable_pybind --build_wheel
```
### Model and Data

The static quantization tool expects the model and its calibration data to be laid out in the following directory structure:

```ps1
work_dir\resnet18-v1-7
├───model.onnx
├───test_data_set_0
├───test_data_set_1
├───test_data_set_2
├───test_data_set_3
├───test_data_set_4
├───test_data_set_5
├───test_data_set_6
├───test_data_set_7
├───test_data_set_8
└───test_data_set_9
```
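Each `test_data_set_N` directory holds one calibration sample as serialized `onnx.TensorProto` files, one per model input (`input_0.pb`, `input_1.pb`, ... in the ONNX model-zoo convention). If a model ships without such data, a minimal sketch along these lines can generate random samples; the file naming, sample count, and the assumption of float32 inputs with fully static shapes are illustrative choices, not something this commit prescribes.

```python
# Sketch: generate test_data_set_N folders the runner can consume.
# Assumes float32 inputs with fully static shapes; adjust for your model.
import os

import numpy as np
import onnx
import onnxruntime

model_path = r"resnet18-v1-7\model.onnx"
session = onnxruntime.InferenceSession(model_path)

for i in range(10):  # arbitrary number of calibration samples
    data_dir = os.path.join(os.path.dirname(model_path), f"test_data_set_{i}")
    os.makedirs(data_dir, exist_ok=True)
    for j, inp in enumerate(session.get_inputs()):
        arr = np.random.rand(*inp.shape).astype(np.float32)
        tensor = onnx.numpy_helper.from_array(arr, name=inp.name)
        with open(os.path.join(data_dir, f"input_{j}.pb"), "wb") as f:
            f.write(tensor.SerializeToString())
```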
### Usage

Install the Python wheel built from onnxruntime, then run the quantization runner:

```ps1
cd work_dir
python -m venv ort_env
ort_env\Scripts\activate
python -m pip install <path-to-built-folder>\RelWithDebInfo\RelWithDebInfo\dist\<name-of-the-wheel>.whl

# The following command yields model_quant.onnx under the same directory "resnet18-v1-7"
python -m onnxruntime.quantization.static_quantize_runner -i resnet18-v1-7\model.onnx -o resnet18-v1-7\model_quant.onnx
```

Afterwards the directory contains the quantized model:

```ps1
work_dir\resnet18-v1-7
├───model.onnx
├───model_quant.onnx
├───test_data_set_0
│   ...
└───test_data_set_9
```
### Quantization Arguments

Please refer to `static_quantize_runner.py` for the full list of arguments. A few examples:

```ps1
python -m onnxruntime.quantization.static_quantize_runner -i resnet18-v1-7\model.onnx -o resnet18-v1-7\model_quant.onnx --activation_type qint8 --weight_type qint16
python -m onnxruntime.quantization.static_quantize_runner -i resnet18-v1-7\model.onnx -o resnet18-v1-7\model_quant.onnx --activation_type qint16 --weight_type qint16 --quantize_bias
python -m onnxruntime.quantization.static_quantize_runner -i resnet18-v1-7\model.onnx -o resnet18-v1-7\model_quant.onnx --activation_type qint16 --weight_type qint8 --per_channel
```
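Under the hood, the runner simply builds a `StaticQuantConfig` and calls `quantize` (see `static_quantize_runner.py` below). A minimal programmatic sketch of the first command above, assuming `data_reader` is a `CalibrationDataReader` such as the `OnnxModelCalibrationDataReader` defined in the runner:

```python
from onnxruntime.quantization import QuantFormat, QuantType, StaticQuantConfig, quantize

# data_reader: a CalibrationDataReader, e.g. OnnxModelCalibrationDataReader below
sqc = StaticQuantConfig(
    calibration_data_reader=data_reader,
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt16,
)
quantize(
    model_input=r"resnet18-v1-7\model.onnx",
    model_output=r"resnet18-v1-7\model_quant.onnx",
    quant_config=sqc,
)
```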
### Tensor Quant Overrides JSON Format

With `--tensor_quant_overrides`, the tool consumes a JSON file containing quantization override information:

```ps1
python -m onnxruntime.quantization.static_quantize_runner -i resnet18-v1-7\model.onnx -o resnet18-v1-7\model_quant.onnx --tensor_quant_overrides <path-to-json>\encoding.json
```

The tool expects `encoding.json` in the following format:

```json
{
  "conv1_1": [
    {
      "scale": 0.005,
      "zero_point": 12
    }
  ]
}
```
- Each key is the name of a tensor in the ONNX model.
  - e.g. "conv1_1"
- For each tensor, a list of dictionaries should be provided.
  - For per-tensor quantization, the list contains a single dictionary.
  - For per-channel quantization, the list contains a dictionary for each channel in the tensor (see the sketch below).
- Each dictionary contains the information required for quantization, including:
  - scale (float)
  - zero_point (int)
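For instance, a per-channel override for a hypothetical weight tensor `conv1_1_weight` with three output channels could look like the sketch below (the tensor name and values are illustrative only; depending on the onnxruntime version, per-channel overrides may additionally expect an `axis` entry):

```json
{
  "conv1_1_weight": [
    { "scale": 0.004, "zero_point": 0 },
    { "scale": 0.011, "zero_point": 3 },
    { "scale": 0.007, "zero_point": -2 }
  ]
}
```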
Lines changed: 259 additions & 0 deletions
```python
import argparse
import json
import os

import numpy as np
import onnx

import onnxruntime
from onnxruntime.quantization import (
    StaticQuantConfig,
    QuantType,
    QuantFormat,
    quantize,
)
from onnxruntime.quantization.calibrate import (
    CalibrationDataReader,
    CalibrationMethod,
)

class OnnxModelCalibrationDataReader(CalibrationDataReader):
    """Feeds the test_data_set_N protobuf samples next to the model to the calibrator."""

    def __init__(self, model_path):
        self.model_dir = os.path.dirname(model_path)
        data_dirs = [
            os.path.join(self.model_dir, a)
            for a in os.listdir(self.model_dir)
            if a.startswith("test_data_set_")
        ]
        model_inputs = onnxruntime.InferenceSession(model_path).get_inputs()
        name2tensors = []
        for data_dir in data_dirs:
            name2tensor = {}
            # Sort so input_0.pb, input_1.pb, ... line up with the model inputs.
            data_paths = [os.path.join(data_dir, a) for a in sorted(os.listdir(data_dir))]
            data_ndarrays = [self.read_onnx_pb_data(data_path) for data_path in data_paths]
            for model_input, data_ndarray in zip(model_inputs, data_ndarrays):
                name2tensor[model_input.name] = data_ndarray
            name2tensors.append(name2tensor)
        assert len(name2tensors) == len(data_dirs)
        assert len(name2tensors[0]) == len(model_inputs)

        self.calibration_data = iter(name2tensors)

    def get_next(self) -> dict:
        """Generate the input-name-to-tensor dict for an InferenceSession run."""
        return next(self.calibration_data, None)

    def read_onnx_pb_data(self, file_pb):
        tensor = onnx.TensorProto()
        with open(file_pb, "rb") as f:
            tensor.ParseFromString(f.read())
        return onnx.numpy_helper.to_array(tensor)

def parse_arguments():
    parser = argparse.ArgumentParser(description="The arguments for static quantization")
    parser.add_argument("-i", "--input_model_path", required=True, help="Path to the input onnx model")
    parser.add_argument(
        "-o", "--output_quantized_model_path", required=True, help="Path to the output quantized onnx model"
    )
    parser.add_argument(
        "--activation_type",
        choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
        default="quint8",
        help="Activation quantization type used",
    )
    parser.add_argument(
        "--weight_type",
        choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
        default="qint8",
        help="Weight quantization type used",
    )
    parser.add_argument("--enable_subgraph", action="store_true", help="If set, subgraphs will be quantized.")
    parser.add_argument(
        "--force_quantize_no_input_check",
        action="store_true",
        help="By default, latent operators such as MaxPool and Transpose are not quantized if their input is not"
        " already quantized. If set, such operators are forced to quantize their input and thus produce quantized"
        " output. This behavior can still be disabled per node via nodes_to_exclude.",
    )
    parser.add_argument(
        "--matmul_const_b_only",
        action="store_true",
        help="If set, only MatMul nodes with a constant B input will be quantized.",
    )
    parser.add_argument(
        "--add_qdq_pair_to_weight",
        action="store_true",
        help="If set, weights are kept in floating point and a QuantizeLinear/DeQuantizeLinear pair is inserted"
        " on each weight.",
    )
    parser.add_argument(
        "--dedicated_qdq_pair", action="store_true", help="If set, a dedicated QDQ pair is created for each node."
    )
    parser.add_argument(
        "--op_types_to_exclude_output_quantization",
        nargs="+",
        default=[],
        help="If any op types are specified, the outputs of ops with those op types will not be quantized.",
    )
    parser.add_argument(
        "--calibration_method",
        default="minmax",
        choices=["minmax", "entropy", "percentile", "distribution"],
        help="Calibration method used",
    )
    parser.add_argument("--quant_format", default="qdq", choices=["qdq", "qoperator"], help="Quantization format used")
    parser.add_argument(
        "--calib_tensor_range_symmetric",
        action="store_true",
        help="If enabled, the final tensor range found during calibration is made symmetric around the"
        " central point 0.",
    )
    # TODO: --calib_strided_minmax
    # TODO: --calib_moving_average_constant
    # TODO: --calib_max_intermediate_outputs
    parser.add_argument(
        "--calib_moving_average",
        action="store_true",
        help="If enabled, a moving average of the minimum and maximum values is computed when the MinMax"
        " calibration method is selected.",
    )
    parser.add_argument(
        "--quantize_bias",
        action="store_true",
        help="Whether to quantize floating-point biases by solely inserting a DeQuantizeLinear node."
        " If not set, biases remain in floating point and no quantization nodes associated with them are inserted.",
    )

    # TODO: Add arguments related to Smooth Quant

    parser.add_argument(
        "--use_qdq_contrib_ops",
        action="store_true",
        help="If set, the inserted QuantizeLinear and DequantizeLinear ops will have the `com.microsoft` domain,"
        " which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear contrib op implementations.",
    )
    parser.add_argument(
        "--minimum_real_range",
        type=float,
        default=0.0001,
        help="If set to a floating-point value, the calculation of the quantization parameters"
        " (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin)"
        " is less than the specified minimum range, rmax will be set to rmin + MinimumRealRange. This is"
        " necessary for EPs like QNN that require a minimum floating-point range when determining"
        " quantization parameters.",
    )
    parser.add_argument(
        "--qdq_keep_removable_activations",
        action="store_true",
        help="If set, removable activations (e.g., Clip or Relu) will not be removed and will be explicitly"
        " represented in the QDQ model.",
    )
    parser.add_argument(
        "--qdq_disable_weight_adjust_for_int32_bias",
        action="store_true",
        help="If set, the QDQ quantizer will not adjust a weight's scale when the bias has a scale"
        " (input_scale * weight_scale) that is too small.",
    )
    parser.add_argument("--per_channel", action="store_true", help="Whether to use per-channel quantization")
    parser.add_argument(
        "--op_per_channel_axis",
        nargs=2,
        action="append",
        metavar=("OP_TYPE", "PER_CHANNEL_AXIS"),
        default=[],
        help="Set the channel axis for a specific op type, e.g. --op_per_channel_axis MatMul 1. Effective only"
        " when the op type supports per-channel quantization and --per_channel is set. If an op type supports"
        " per-channel quantization but no channel axis is explicitly specified, the default axis is used.",
    )
    parser.add_argument("--tensor_quant_overrides", help="Path to the JSON file with tensor quantization overrides.")
    return parser.parse_args()

def get_tensor_quant_overrides(file):
    # TODO: Enhance the function to handle more real cases of json file
    if not file:
        return {}
    with open(file, "r") as f:
        quant_override_dict = json.load(f)
    # The quantizer expects scale/zero_point as numpy values rather than plain Python numbers.
    for tensor in quant_override_dict:
        for enc_dict in quant_override_dict[tensor]:
            enc_dict["scale"] = np.array(enc_dict["scale"], dtype=np.float32)
            enc_dict["zero_point"] = np.array(enc_dict["zero_point"])
    return quant_override_dict

def main():
    args = parse_arguments()
    data_reader = OnnxModelCalibrationDataReader(model_path=args.input_model_path)
    arg2quant_type = {
        "qint8": QuantType.QInt8,
        "quint8": QuantType.QUInt8,
        "qint16": QuantType.QInt16,
        "quint16": QuantType.QUInt16,
        "qint4": QuantType.QInt4,
        "quint4": QuantType.QUInt4,
        "qfloat8e4m3fn": QuantType.QFLOAT8E4M3FN,
    }
    activation_type = arg2quant_type[args.activation_type]
    weight_type = arg2quant_type[args.weight_type]
    # Convert the parsed [op_type, axis] pairs into {op_type: int(axis)}.
    qdq_op_type_per_channel_support_to_axis = {op: int(axis) for op, axis in args.op_per_channel_axis}
    extra_options = {
        "EnableSubgraph": args.enable_subgraph,
        "ForceQuantizeNoInputCheck": args.force_quantize_no_input_check,
        "MatMulConstBOnly": args.matmul_const_b_only,
        "AddQDQPairToWeight": args.add_qdq_pair_to_weight,
        "OpTypesToExcludeOutputQuantization": args.op_types_to_exclude_output_quantization,
        "DedicatedQDQPair": args.dedicated_qdq_pair,
        "QDQOpTypePerChannelSupportToAxis": qdq_op_type_per_channel_support_to_axis,
        "CalibTensorRangeSymmetric": args.calib_tensor_range_symmetric,
        "CalibMovingAverage": args.calib_moving_average,
        "UseQDQContribOps": args.use_qdq_contrib_ops,
        "MinimumRealRange": args.minimum_real_range,
        "QDQKeepRemovableActivations": args.qdq_keep_removable_activations,
        "QDQDisableWeightAdjustForInt32Bias": args.qdq_disable_weight_adjust_for_int32_bias,
        # Load the JSON file with encoding overrides
        "TensorQuantOverrides": get_tensor_quant_overrides(args.tensor_quant_overrides),
    }
    arg2calib_method = {
        "minmax": CalibrationMethod.MinMax,
        "entropy": CalibrationMethod.Entropy,
        "percentile": CalibrationMethod.Percentile,
        "distribution": CalibrationMethod.Distribution,
    }
    arg2quant_format = {
        "qdq": QuantFormat.QDQ,
        "qoperator": QuantFormat.QOperator,
    }
    sqc = StaticQuantConfig(
        calibration_data_reader=data_reader,
        calibrate_method=arg2calib_method[args.calibration_method],
        quant_format=arg2quant_format[args.quant_format],
        activation_type=activation_type,
        weight_type=weight_type,
        op_types_to_quantize=None,
        nodes_to_quantize=None,
        nodes_to_exclude=None,
        per_channel=args.per_channel,
        reduce_range=False,
        use_external_data_format=False,
        calibration_providers=None,  # Use CPUExecutionProvider
        extra_options=extra_options,
    )
    quantize(
        model_input=args.input_model_path,
        model_output=args.output_quantized_model_path,
        quant_config=sqc,
    )


if __name__ == "__main__":
    main()
```
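Not part of this commit, but a quick sanity check of the produced model is to run the float and quantized models on one calibration sample and compare outputs. The `input_N.pb` naming below follows the model-zoo convention assumed throughout:

```python
# Sketch: compare the float and quantized models on one calibration sample.
import numpy as np
import onnx
import onnxruntime


def load_sample(data_dir, session):
    feeds = {}
    for j, inp in enumerate(session.get_inputs()):
        tensor = onnx.TensorProto()
        with open(rf"{data_dir}\input_{j}.pb", "rb") as f:
            tensor.ParseFromString(f.read())
        feeds[inp.name] = onnx.numpy_helper.to_array(tensor)
    return feeds


float_sess = onnxruntime.InferenceSession(r"resnet18-v1-7\model.onnx")
quant_sess = onnxruntime.InferenceSession(r"resnet18-v1-7\model_quant.onnx")
feeds = load_sample(r"resnet18-v1-7\test_data_set_0", float_sess)
ref = float_sess.run(None, feeds)[0]
out = quant_sess.run(None, feeds)[0]
print("max abs diff:", np.max(np.abs(ref - out)))
```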
