import argparse
import json
import os

import numpy as np
import onnx

import onnxruntime
from onnxruntime.quantization import QuantFormat, QuantType, StaticQuantConfig, quantize
from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod


class OnnxModelCalibrationDataReader(CalibrationDataReader):
    """Feeds the model's test_data_set_* protobuf inputs to the calibrator."""

    def __init__(self, model_path):
        self.model_dir = os.path.dirname(model_path)
        data_dirs = [
            os.path.join(self.model_dir, a) for a in os.listdir(self.model_dir) if a.startswith("test_data_set_")
        ]
        model_inputs = onnxruntime.InferenceSession(model_path).get_inputs()
        name2tensors = []
        for data_dir in data_dirs:
            name2tensor = {}
            data_paths = [os.path.join(data_dir, a) for a in sorted(os.listdir(data_dir))]
            data_ndarrays = [self.read_onnx_pb_data(data_path) for data_path in data_paths]
            # Sorted file order pairs the input_*.pb files with the session inputs;
            # strict=False drops any extra files (e.g. output_*.pb, which sort after the inputs).
            for model_input, data_ndarray in zip(model_inputs, data_ndarrays, strict=False):
                name2tensor[model_input.name] = data_ndarray
            name2tensors.append(name2tensor)
        assert len(name2tensors) == len(data_dirs)
        assert len(name2tensors[0]) == len(model_inputs)

        self.calibration_data = iter(name2tensors)
    def get_next(self) -> dict:
        """Generate the input data dict for an ONNX Runtime InferenceSession run.

        Returns None once the calibration data is exhausted, which tells the calibrator to stop.
        """
        return next(self.calibration_data, None)

    def read_onnx_pb_data(self, file_pb):
        tensor = onnx.TensorProto()
        with open(file_pb, "rb") as f:
            tensor.ParseFromString(f.read())
        return onnx.numpy_helper.to_array(tensor)


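# The data reader above assumes the ONNX model-zoo test-data layout next to the
# model, sketched here with illustrative file names:
#
#   <model_dir>/
#       model.onnx
#       test_data_set_0/
#           input_0.pb      # one TensorProto per model input, in sorted order
#           input_1.pb
#           output_0.pb     # expected outputs, if present, are ignored
#       test_data_set_1/
#           ...
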
def parse_arguments():
    parser = argparse.ArgumentParser(description="Arguments for static quantization")
    parser.add_argument("-i", "--input_model_path", required=True, help="Path to the input ONNX model")
    parser.add_argument(
        "-o", "--output_quantized_model_path", required=True, help="Path to the output quantized ONNX model"
    )
    parser.add_argument(
        "--activation_type",
        choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
        default="quint8",
        help="Activation quantization type used",
    )
    parser.add_argument(
        "--weight_type",
        choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
        default="qint8",
        help="Weight quantization type used",
    )
    parser.add_argument("--enable_subgraph", action="store_true", help="If set, subgraphs will be quantized.")
    parser.add_argument(
        "--force_quantize_no_input_check",
        action="store_true",
        help="By default, some latent operators such as MaxPool and Transpose are not quantized if their input"
        " is not already quantized. If set, such operators are forced to quantize their input and thus produce"
        " quantized output. This behavior can still be disabled per node via --nodes_to_exclude.",
    )
    parser.add_argument(
        "--matmul_const_b_only",
        action="store_true",
        help="If set, only MatMul ops with a constant B input will be quantized.",
    )
    parser.add_argument(
        "--add_qdq_pair_to_weight",
        action="store_true",
        help="If set, weights remain in floating point and a full QuantizeLinear/DeQuantizeLinear pair"
        " is inserted for each weight.",
    )
    parser.add_argument(
        "--dedicated_qdq_pair",
        action="store_true",
        help="If set, a dedicated, identical QDQ pair is created for each node.",
    )
    parser.add_argument(
        "--op_types_to_exclude_output_quantization",
        nargs="+",
        default=[],
        help="If any op types are specified, the outputs of ops with those op types will not be quantized.",
    )
    parser.add_argument(
        "--calibration_method",
        default="minmax",
        choices=["minmax", "entropy", "percentile", "distribution"],
        help="Calibration method used",
    )
    parser.add_argument("--quant_format", default="qdq", choices=["qdq", "qoperator"], help="Quantization format used")
    parser.add_argument(
        "--calib_tensor_range_symmetric",
        action="store_true",
        help="If set, the final tensor range computed during calibration will be made symmetric"
        " around the central point 0.",
    )
    # TODO: --calib_strided_minmax
    # TODO: --calib_moving_average_constant
    # TODO: --calib_max_intermediate_outputs
    parser.add_argument(
        "--calib_moving_average",
        action="store_true",
        help="If set, the moving averages of the minimum and maximum values will be computed"
        " when the selected calibration method is MinMax.",
    )
    parser.add_argument(
        "--disable_quantize_bias",
        action="store_true",
        help="If set, biases remain in floating point and no quantization nodes are inserted for them."
        " By default, floating-point biases are quantized by solely inserting a DeQuantizeLinear node.",
    )

    # TODO: Add arguments related to Smooth Quant

    parser.add_argument(
        "--use_qdq_contrib_ops",
        action="store_true",
        help="If set, the inserted QuantizeLinear and DequantizeLinear ops will have the com.microsoft domain,"
        " which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear contrib op implementations.",
    )
    parser.add_argument(
        "--minimum_real_range",
        type=float,
        default=0.0001,
        help="If set to a floating-point value, the calculation of the quantization parameters"
        " (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin)"
        " is less than the specified minimum range, rmax will be set to rmin + MinimumRealRange. This is"
        " necessary for EPs like QNN that require a minimum floating-point range when determining"
        " quantization parameters.",
    )
    parser.add_argument(
        "--qdq_keep_removable_activations",
        action="store_true",
        help="If set, removable activations (e.g., Clip or Relu) will not be removed,"
        " and will be explicitly represented in the QDQ model.",
    )
    parser.add_argument(
        "--qdq_disable_weight_adjust_for_int32_bias",
        action="store_true",
        help="If set, the QDQ quantizer will not adjust the weight's scale when the bias"
        " has a scale (input_scale * weight_scale) that is too small.",
    )
    parser.add_argument("--per_channel", action="store_true", help="Whether to use per-channel quantization")
    parser.add_argument(
        "--nodes_to_quantize",
        nargs="+",
        default=None,
        help="List of node names to quantize. When this list is given, only the nodes in it are quantized.",
    )
    parser.add_argument(
        "--nodes_to_exclude",
        nargs="+",
        default=None,
        help="List of node names to exclude. When this list is given, the nodes in it are excluded from quantization.",
    )
    parser.add_argument(
        "--op_per_channel_axis",
        nargs=2,
        action="append",
        metavar=("OP_TYPE", "PER_CHANNEL_AXIS"),
        default=[],
        help="Set the channel axis for a specific op type, for example: --op_per_channel_axis MatMul 1. It is"
        " effective only when per-channel quantization is supported for that op type and --per_channel is set."
        " If an op type supports per-channel quantization but no channel axis is explicitly specified,"
        " the default channel axis will be used.",
    )
    parser.add_argument("--tensor_quant_overrides", help="Path to the JSON file with tensor quantization overrides.")
    return parser.parse_args()


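# Example invocation of this script (the script and model file names are illustrative):
#
#   python static_quantize.py -i model.onnx -o model.quant.onnx \
#       --calibration_method minmax --activation_type quint8 --weight_type qint8
#
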
def get_tensor_quant_overrides(file):
    # TODO: Enhance the function to handle more real cases of the JSON file
    if not file:
        return {}
    with open(file) as f:
        quant_override_dict = json.load(f)
    for tensor in quant_override_dict:
        for enc_dict in quant_override_dict[tensor]:
            # The quantizer expects numpy values rather than plain JSON numbers.
            enc_dict["scale"] = np.array(enc_dict["scale"], dtype=np.float32)
            enc_dict["zero_point"] = np.array(enc_dict["zero_point"])
    return quant_override_dict


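# A minimal sketch of the overrides JSON this function handles: a map from tensor
# name to a list of encoding dicts (tensor name and values are illustrative):
#
#   {
#       "conv1_output": [
#           {"scale": 0.0235, "zero_point": 128}
#       ]
#   }
#
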
def main():
    args = parse_arguments()
    data_reader = OnnxModelCalibrationDataReader(model_path=args.input_model_path)
    arg2quant_type = {
        "qint8": QuantType.QInt8,
        "quint8": QuantType.QUInt8,
        "qint16": QuantType.QInt16,
        "quint16": QuantType.QUInt16,
        "qint4": QuantType.QInt4,
        "quint4": QuantType.QUInt4,
        "qfloat8e4m3fn": QuantType.QFLOAT8E4M3FN,
    }
    activation_type = arg2quant_type[args.activation_type]
    weight_type = arg2quant_type[args.weight_type]
    # argparse yields the axis as a string; convert it to int for the quantizer.
    qdq_op_type_per_channel_support_to_axis = {op_type: int(axis) for op_type, axis in args.op_per_channel_axis}
    extra_options = {
        "EnableSubgraph": args.enable_subgraph,
        "ForceQuantizeNoInputCheck": args.force_quantize_no_input_check,
        "MatMulConstBOnly": args.matmul_const_b_only,
        "AddQDQPairToWeight": args.add_qdq_pair_to_weight,
        "OpTypesToExcludeOutputQuantization": args.op_types_to_exclude_output_quantization,
        "DedicatedQDQPair": args.dedicated_qdq_pair,
        "QDQOpTypePerChannelSupportToAxis": qdq_op_type_per_channel_support_to_axis,
        "CalibTensorRangeSymmetric": args.calib_tensor_range_symmetric,
        "CalibMovingAverage": args.calib_moving_average,
        "QuantizeBias": not args.disable_quantize_bias,
        "UseQDQContribOps": args.use_qdq_contrib_ops,
        "MinimumRealRange": args.minimum_real_range,
        "QDQKeepRemovableActivations": args.qdq_keep_removable_activations,
        "QDQDisableWeightAdjustForInt32Bias": args.qdq_disable_weight_adjust_for_int32_bias,
        # Load the JSON file for encoding overrides.
        "TensorQuantOverrides": get_tensor_quant_overrides(args.tensor_quant_overrides),
    }
    arg2calib_method = {
        "minmax": CalibrationMethod.MinMax,
        "entropy": CalibrationMethod.Entropy,
        "percentile": CalibrationMethod.Percentile,
        "distribution": CalibrationMethod.Distribution,
    }
    arg2quant_format = {
        "qdq": QuantFormat.QDQ,
        "qoperator": QuantFormat.QOperator,
    }
    sqc = StaticQuantConfig(
        calibration_data_reader=data_reader,
        calibrate_method=arg2calib_method[args.calibration_method],
        quant_format=arg2quant_format[args.quant_format],
        activation_type=activation_type,
        weight_type=weight_type,
        op_types_to_quantize=None,
        nodes_to_quantize=args.nodes_to_quantize,
        nodes_to_exclude=args.nodes_to_exclude,
        per_channel=args.per_channel,
        reduce_range=False,
        use_external_data_format=False,
        calibration_providers=None,  # Use CPUExecutionProvider
        extra_options=extra_options,
    )
    quantize(model_input=args.input_model_path, model_output=args.output_quantized_model_path, quant_config=sqc)


if __name__ == "__main__":
    main()