@@ -1,7 +1,7 @@
 from enum import Enum

 from tqdm import tqdm
-from typing import Set
+from typing import Set, List, Optional
 import onnx
 import os
@@ -110,6 +110,16 @@ class QuantizationArguments:
         },
     )

+    op_block_list: List[str] = field(
+        default=None,
+        metadata={
+            "help": "List of operators to exclude from quantization. "
+            "Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/) "
+            "or your custom-implemented operators.",
+            "nargs": "+",
+        },
+    )
+

 def get_operators(model: onnx.ModelProto) -> Set[str]:
     operators = set()
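For reference, a hypothetical invocation once this field is in place, assuming the script's existing argument-parser wiring honors the "nargs" metadata; every flag other than --op_block_list, and both folder paths, are placeholders rather than confirmed names:

    python -m scripts.quantize --input_folder ./models/onnx --output_folder ./models/onnx --modes fp16 --op_block_list LayerNormalization Softmax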
@@ -131,6 +141,7 @@ def quantize_q8(
     per_channel: bool,
     reduce_range: bool,
     weight_type: QuantType,
+    op_block_list: Optional[List[str]],
 ):
     """
     Quantize the weights of the model from float32 to int8/uint8
@@ -140,6 +151,10 @@ def quantize_q8(
     it is faster on most CPU architectures
     """

+    op_types_to_quantize = set(IntegerOpsRegistry.keys())
+    if op_block_list is not None:
+        op_types_to_quantize.difference_update(op_block_list)
+
     quantizer = ONNXQuantizer(
         model,
         per_channel,
@@ -151,7 +166,7 @@ def quantize_q8(
         tensors_range=None,
         nodes_to_quantize=[],
         nodes_to_exclude=[],
-        op_types_to_quantize=list(IntegerOpsRegistry.keys()),
+        op_types_to_quantize=op_types_to_quantize,
         extra_options=dict(
             EnableSubgraph=True,
             MatMulConstBOnly=True,
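To illustrate the effect of the new filtering in quantize_q8, a minimal sketch; the operator names below are stand-ins for illustration, not the actual contents of onnxruntime's IntegerOpsRegistry:

    # Example only: a stand-in for IntegerOpsRegistry.keys().
    registry_ops = {"MatMul", "Attention", "Gather", "EmbedLayerNormalization"}
    op_block_list = ["Gather"]

    op_types_to_quantize = set(registry_ops)
    if op_block_list is not None:
        op_types_to_quantize.difference_update(op_block_list)

    # "Gather" nodes are now skipped; the remaining op types still quantize.
    assert op_types_to_quantize == {"MatMul", "Attention", "EmbedLayerNormalization"}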
@@ -165,6 +180,7 @@ def quantize_q8(
 def quantize_fp16(
     model: onnx.ModelProto,
     save_path: str,
+    op_block_list: Optional[List[str]],
 ):
     """
     Quantize the weights of the model from float32 to float16
@@ -174,10 +190,15 @@ def quantize_fp16(
     # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2338583841
     disable_shape_infer = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF

+    blocked_ops = set(float16.DEFAULT_OP_BLOCK_LIST)
+    if op_block_list is not None:
+        blocked_ops.update(op_block_list)
+
     model_fp16 = float16.convert_float_to_float16(
         model,
         keep_io_types=True,
         disable_shape_infer=disable_shape_infer,
+        op_block_list=blocked_ops,
     )
     graph = gs.import_onnx(model_fp16)
     graph.toposort()
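Worth noting: because blocked_ops starts from float16.DEFAULT_OP_BLOCK_LIST and is grown with update(), a user-supplied op_block_list extends the default float16 block list rather than replacing it. A minimal illustration, where the default values shown are made-up stand-ins:

    # Stand-in values for float16.DEFAULT_OP_BLOCK_LIST, for illustration only.
    DEFAULT_OP_BLOCK_LIST = ["Resize", "Upsample"]

    blocked_ops = set(DEFAULT_OP_BLOCK_LIST)
    blocked_ops.update(["Softmax"])  # user-supplied op_block_list
    # Result: {"Resize", "Upsample", "Softmax"}; defaults are kept, not overridden.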
@@ -271,6 +292,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
             quantize_fp16(
                 model,
                 save_path,
+                quantization_args.op_block_list,
             )

         elif mode in (QuantMode.Q4, QuantMode.Q4F16):
@@ -287,6 +309,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
                 quantize_fp16(
                     q4_model,
                     save_path,
+                    quantization_args.op_block_list,
                 )

         elif mode == QuantMode.BNB4:
@@ -331,6 +354,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
                 per_channel=quantization_args.per_channel,
                 reduce_range=quantization_args.reduce_range,
                 weight_type=weight_type,
+                op_block_list=quantization_args.op_block_list,
             )

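Taken together, the new parameter can also be exercised directly from Python. A minimal sketch, assuming quantize_fp16 persists the converted model to save_path as its callers above imply; the file paths and the blocked operator are placeholders:

    import onnx

    model = onnx.load("model.onnx")
    quantize_fp16(
        model,
        "model_fp16.onnx",
        op_block_list=["Softmax"],  # keep Softmax in float32 for accuracy
    )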