import argparse
import json
import os

import numpy as np
import onnx
import onnxruntime
from onnxruntime.quantization import (
    StaticQuantConfig,
    QuantType,
    QuantFormat,
    quantize
)
from onnxruntime.quantization.calibrate import (
    CalibrationDataReader,
    CalibrationMethod
)

class OnnxModelCalibrationDataReader(CalibrationDataReader):
    """Feeds the model's test_data_set_* protobuf tensors to the calibrator."""

    def __init__(self, model_path):
        self.model_dir = os.path.dirname(model_path)
        data_dirs = [
            os.path.join(self.model_dir, a) for a in os.listdir(self.model_dir)
            if a.startswith('test_data_set_')
        ]
        assert data_dirs, f"No test_data_set_* directories found next to {model_path}"
        model_inputs = onnxruntime.InferenceSession(model_path).get_inputs()
        name2tensors = []
        for data_dir in data_dirs:
            name2tensor = {}
            # Keep only the input_*.pb files and sort them so they line up with the
            # model inputs (the standard layout also stores output_*.pb files here,
            # and os.listdir returns entries in arbitrary order).
            data_paths = sorted(
                os.path.join(data_dir, a) for a in os.listdir(data_dir)
                if a.startswith('input_')
            )
            data_ndarrays = [self.read_onnx_pb_data(data_path) for data_path in data_paths]
            for model_input, data_ndarray in zip(model_inputs, data_ndarrays):
                name2tensor[model_input.name] = data_ndarray
            name2tensors.append(name2tensor)
        assert len(name2tensors) == len(data_dirs)
        assert len(name2tensors[0]) == len(model_inputs)

        self.calibration_data = iter(name2tensors)

    def get_next(self) -> dict:
        """Generate the input data dict for an ONNX Runtime InferenceSession run."""
        return next(self.calibration_data, None)

    def read_onnx_pb_data(self, file_pb):
        """Parse a serialized onnx.TensorProto file into a numpy array."""
        tensor = onnx.TensorProto()
        with open(file_pb, 'rb') as f:
            tensor.ParseFromString(f.read())
        return onnx.numpy_helper.to_array(tensor)

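# Illustrative only: a minimal sketch of how a CalibrationDataReader is consumed.
# The calibrator calls get_next() until it returns None and runs the model on each
# feed dict; the paths below are hypothetical.
#
#   reader = OnnxModelCalibrationDataReader("models/model.onnx")
#   session = onnxruntime.InferenceSession("models/model.onnx")
#   while (feeds := reader.get_next()) is not None:
#       outputs = session.run(None, feeds)
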
def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Statically quantize an ONNX model with ONNX Runtime"
    )
    parser.add_argument("-i", "--input_model_path", required=True, help="Path to the input onnx model")
    parser.add_argument("-o", "--output_quantized_model_path", required=True, help="Path to the output quantized onnx model")
    parser.add_argument(
        "--activation_type",
        choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
        default="quint8",
        help="Activation quantization type used"
    )
    parser.add_argument(
        "--weight_type",
        choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
        default="qint8",
        help="Weight quantization type used"
    )
    parser.add_argument(
        "--enable_subgraph", action="store_true", help="If set, subgraphs will be quantized."
    )
    parser.add_argument(
        "--force_quantize_no_input_check",
        action="store_true",
        help="By default, latent operators such as MaxPool and Transpose are not quantized unless their"
        " input is already quantized. If set, such operators always quantize their input and hence"
        " produce quantized output. This behavior can still be disabled per node via nodes_to_exclude."
    )
    parser.add_argument(
        "--matmul_const_b_only",
        action="store_true",
        help="If set, only MatMul nodes with a constant B input will be quantized."
    )
    parser.add_argument(
        "--add_qdq_pair_to_weight",
        action="store_true",
        help="If set, weights remain floating point and a QuantizeLinear/DeQuantizeLinear pair is"
        " inserted for each weight."
    )
    parser.add_argument(
        "--dedicated_qdq_pair",
        action="store_true",
        help="If set, an identical, dedicated QDQ pair is created for each node."
    )
    parser.add_argument(
        "--op_types_to_exclude_output_quantization",
        nargs='+',
        default=[],
        help="If any op types are specified, the outputs of ops with those op types are not quantized."
    )
    parser.add_argument(
        "--calibration_method",
        default="minmax",
        choices=["minmax", "entropy", "percentile", "distribution"],
        help="Calibration method used"
    )
    parser.add_argument(
        "--quant_format", default="qdq", choices=["qdq", "qoperator"], help="Quantization format used"
    )
    parser.add_argument(
        "--calib_tensor_range_symmetric",
        action="store_true",
        help="If set, the final tensor range computed during calibration is made symmetric around the"
        " central point 0."
    )
    # TODO: --calib_strided_minmax
    # TODO: --calib_moving_average_constant
    # TODO: --calib_max_intermediate_outputs
    parser.add_argument(
        "--calib_moving_average",
        action="store_true",
        help="If set, the moving average of the minimum and maximum values is computed when the"
        " calibration method is MinMax."
    )
    parser.add_argument(
        "--quantize_bias",
        action="store_true",
        help="Whether to quantize floating-point biases by solely inserting a DeQuantizeLinear node."
        " If not set, biases remain floating point and no quantization nodes are inserted for them."
    )

    # TODO: Add arguments related to Smooth Quant

    parser.add_argument(
        "--use_qdq_contrib_ops",
        action="store_true",
        help="If set, the inserted QuantizeLinear and DequantizeLinear ops will have the `com.microsoft`"
        " domain, which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear contrib op"
        " implementations."
    )
    parser.add_argument(
        "--minimum_real_range",
        type=float,
        default=0.0001,
        help="Enforce a minimum range between rmin and rmax when computing the quantization parameters"
        " (i.e., scale and zero point). If (rmax - rmin) is less than the specified minimum range, rmax"
        " is set to rmin + minimum_real_range. This is necessary for EPs like QNN that require a"
        " minimum floating-point range when determining quantization parameters."
    )
    parser.add_argument(
        "--qdq_keep_removable_activations",
        action="store_true",
        help="If set, removable activations (e.g., Clip or Relu) will not be removed and will be"
        " explicitly represented in the QDQ model."
    )
    parser.add_argument(
        "--qdq_disable_weight_adjust_for_int32_bias",
        action="store_true",
        help="If set, the QDQ quantizer will not adjust the weight's scale when the bias has a scale"
        " (input_scale * weight_scale) that is too small."
    )
    parser.add_argument("--per_channel", action="store_true", help="Whether to use per-channel quantization")
    parser.add_argument(
        "--op_per_channel_axis",
        nargs=2,
        action="append",
        metavar=('OP_TYPE', 'PER_CHANNEL_AXIS'),
        default=[],
        help="Set the channel axis for a specific op type, for example: --op_per_channel_axis MatMul 1."
        " Effective only when the op type supports per-channel quantization and --per_channel is set."
        " If an op type supports per-channel quantization but no axis is explicitly specified, the"
        " default channel axis is used."
    )
    parser.add_argument(
        "--tensor_quant_overrides",
        help="Path to a JSON file with tensor quantization overrides."
    )
    return parser.parse_args()

def get_tensor_quant_overrides(file):
    # TODO: Enhance the function to handle more real cases of json file
    if not file:
        return {}
    with open(file, "r") as f:
        quant_override_dict = json.load(f)
    for tensor in quant_override_dict:
        for enc_dict in quant_override_dict[tensor]:
            enc_dict["scale"] = np.array(enc_dict["scale"], dtype=np.float32)
            enc_dict["zero_point"] = np.array(enc_dict["zero_point"])
    return quant_override_dict

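# Illustrative only: get_tensor_quant_overrides() assumes a JSON file that maps each
# tensor name to a list of encoding dicts with "scale" and "zero_point" entries,
# e.g. (hypothetical tensor name and values):
#
#   {
#       "conv1_output": [
#           {"scale": 0.015, "zero_point": 128}
#       ]
#   }
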
def main():
    args = parse_arguments()
    data_reader = OnnxModelCalibrationDataReader(
        model_path=args.input_model_path
    )
    arg2quant_type = {
        "qint8": QuantType.QInt8,
        "quint8": QuantType.QUInt8,
        "qint16": QuantType.QInt16,
        "quint16": QuantType.QUInt16,
        "qint4": QuantType.QInt4,
        "quint4": QuantType.QUInt4,
        "qfloat8e4m3fn": QuantType.QFLOAT8E4M3FN
    }
    activation_type = arg2quant_type[args.activation_type]
    weight_type = arg2quant_type[args.weight_type]
    # argparse yields the axis as a string; ONNX Runtime expects an integer axis.
    qdq_op_type_per_channel_support_to_axis = {
        op_type: int(axis) for op_type, axis in args.op_per_channel_axis
    }
    extra_options = {
        "EnableSubgraph": args.enable_subgraph,
        "ForceQuantizeNoInputCheck": args.force_quantize_no_input_check,
        "MatMulConstBOnly": args.matmul_const_b_only,
        "AddQDQPairToWeight": args.add_qdq_pair_to_weight,
        "OpTypesToExcludeOutputQuantization": args.op_types_to_exclude_output_quantization,
        "DedicatedQDQPair": args.dedicated_qdq_pair,
        "QDQOpTypePerChannelSupportToAxis": qdq_op_type_per_channel_support_to_axis,
        "CalibTensorRangeSymmetric": args.calib_tensor_range_symmetric,
        "CalibMovingAverage": args.calib_moving_average,
        "QuantizeBias": args.quantize_bias,
        "UseQDQContribOps": args.use_qdq_contrib_ops,
        "MinimumRealRange": args.minimum_real_range,
        "QDQKeepRemovableActivations": args.qdq_keep_removable_activations,
        "QDQDisableWeightAdjustForInt32Bias": args.qdq_disable_weight_adjust_for_int32_bias,
        # Load the JSON file with per-tensor encoding overrides
        "TensorQuantOverrides": get_tensor_quant_overrides(args.tensor_quant_overrides)
    }
    arg2calib_method = {
        "minmax": CalibrationMethod.MinMax,
        "entropy": CalibrationMethod.Entropy,
        "percentile": CalibrationMethod.Percentile,
        "distribution": CalibrationMethod.Distribution
    }
    arg2quant_format = {
        "qdq": QuantFormat.QDQ,
        "qoperator": QuantFormat.QOperator,
    }
    sqc = StaticQuantConfig(
        calibration_data_reader=data_reader,
        calibrate_method=arg2calib_method[args.calibration_method],
        quant_format=arg2quant_format[args.quant_format],
        activation_type=activation_type,
        weight_type=weight_type,
        op_types_to_quantize=None,
        nodes_to_quantize=None,
        nodes_to_exclude=None,
        per_channel=args.per_channel,
        reduce_range=False,
        use_external_data_format=False,
        calibration_providers=None,  # None defaults to CPUExecutionProvider
        extra_options=extra_options
    )
    quantize(
        model_input=args.input_model_path,
        model_output=args.output_quantized_model_path,
        quant_config=sqc
    )


if __name__ == '__main__':
    main()
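
# Illustrative only: an example invocation with hypothetical paths (the script file
# name is assumed; every flag maps to an option defined in parse_arguments()):
#
#   python static_quantize.py -i model.onnx -o model.quant.onnx \
#       --activation_type quint8 --weight_type qint8 \
#       --calibration_method minmax --quant_format qdq --per_channel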