 from enum import Enum
 
 from tqdm import tqdm
-from typing import Set
+from typing import Set, List
 import onnx
 import os
 
@@ -68,6 +68,16 @@ class QuantizationArguments:
         },
     )
 
+    op_block_list: List[str] = field(
+        default_factory=list,
+        metadata={
+            "help": """List of operators to exclude from quantization.
+            Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/)
+            or any custom operator you have implemented.""",
+            "nargs": "+",
+        },
+    )
+
     # 8-bit quantization
     per_channel: bool = field(
         default=None, metadata={"help": "Whether to quantize weights per channel"}
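Because the field metadata sets `"nargs": "+"`, multiple operator names can be passed space-separated on the command line. A minimal sketch of how this surfaces, assuming `QuantizationArguments` is parsed with `HfArgumentParser` (which forwards field metadata to argparse) and that every other field carries a default; the operator names are illustrative:

    # Hypothetical sketch: parsing the new flag with HfArgumentParser.
    # Assumes all other QuantizationArguments fields have defaults;
    # operator names are examples only.
    from transformers import HfArgumentParser

    parser = HfArgumentParser(QuantizationArguments)
    (quantization_args,) = parser.parse_args_into_dataclasses(
        ["--op_block_list", "LayerNormalization", "Gather"]
    )
    print(quantization_args.op_block_list)  # ['LayerNormalization', 'Gather']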
@@ -131,6 +141,7 @@ def quantize_q8(
     per_channel: bool,
     reduce_range: bool,
     weight_type: QuantType,
+    op_block_list: List[str] = [],
 ):
     """
     Quantize the weights of the model from float32 to int8/uint8
@@ -151,7 +162,9 @@
         tensors_range=None,
         nodes_to_quantize=[],
         nodes_to_exclude=[],
-        op_types_to_quantize=list(IntegerOpsRegistry.keys()),
+        op_types_to_quantize=[
+            op for op in IntegerOpsRegistry.keys() if op not in op_block_list
+        ],
         extra_options=dict(
             EnableSubgraph=True,
             MatMulConstBOnly=True,
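For context, `IntegerOpsRegistry` (from `onnxruntime.quantization.registry`) maps each op type the dynamic quantizer knows how to handle to its quantizer class, so filtering its keys is what leaves blocked node types in float32. A standalone sketch of the filter above, with example operator names:

    # Sketch of the filter introduced above: op types named in
    # op_block_list are dropped from op_types_to_quantize, so the
    # quantizer leaves those nodes untouched. Example values only.
    from onnxruntime.quantization.registry import IntegerOpsRegistry

    op_block_list = ["Conv", "Gather"]
    op_types_to_quantize = [
        op for op in IntegerOpsRegistry.keys() if op not in op_block_list
    ]
    assert "Conv" not in op_types_to_quantize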
@@ -165,6 +178,7 @@
 def quantize_fp16(
     model: onnx.ModelProto,
     save_path: str,
+    op_block_list: List[str] = [],
 ):
     """
     Quantize the weights of the model from float32 to float16
@@ -178,6 +192,7 @@ def quantize_fp16(
         model,
         keep_io_types=True,
         disable_shape_infer=disable_shape_infer,
+        op_block_list=op_block_list,
     )
     graph = gs.import_onnx(model_fp16)
     graph.toposort()
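Here `op_block_list` is forwarded to `float16.convert_float_to_float16`, a parameter exposed both by onnxconverter-common and by the `onnxruntime.transformers` copy of the helper: nodes whose op type is listed stay in float32, with Cast nodes inserted at their boundaries. A minimal sketch assuming the onnxconverter-common helper; file paths and op names are illustrative:

    # Standalone sketch of the downstream call, assuming the
    # onnxconverter-common float16 helper; paths/op names illustrative.
    import onnx
    from onnxconverter_common import float16

    model = onnx.load("model.onnx")
    model_fp16 = float16.convert_float_to_float16(
        model,
        keep_io_types=True,
        op_block_list=["SimplifiedLayerNormalization", "Softmax"],
    )
    onnx.save(model_fp16, "model_fp16.onnx")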
@@ -191,6 +207,7 @@ def quantize_q4(
     block_size: int,
     is_symmetric: bool,
     accuracy_level: int,
+    op_block_list: List[str] = [],
 ):
     """
     Quantize the weights of the model from float32 to 4-bit int
@@ -213,6 +230,7 @@ def quantize_bnb4(
     save_path: str,
     block_size: int,
     quant_type: int,
+    op_block_list: List[str] = [],
 ):
     """
     Quantize the weights of the model from float32 to 4-bit int using MatMulBnb4Quantizer
@@ -282,6 +300,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
                 block_size=block_size,
                 is_symmetric=quantization_args.is_symmetric,
                 accuracy_level=quantization_args.accuracy_level,
+                op_block_list=quantization_args.op_block_list,
             )
             if mode == QuantMode.Q4F16:
                 quantize_fp16(
@@ -299,6 +318,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
                     if quantization_args.quant_type is not None
                     else MatMulBnb4Quantizer.NF4
                 ),
+                op_block_list=quantization_args.op_block_list,
             )
 
         elif mode in (QuantMode.Q8, QuantMode.QI8, QuantMode.QU8):
@@ -331,6 +351,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
                 per_channel=quantization_args.per_channel,
                 reduce_range=quantization_args.reduce_range,
                 weight_type=weight_type,
+                op_block_list=quantization_args.op_block_list,
             )
 
 
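Putting it together, the new field threads from the CLI through `quantize()` into each per-mode helper. A hypothetical end-to-end call, constructing the dataclass directly and relying on the remaining fields' defaults; folder paths and operator names are illustrative:

    # Hypothetical end-to-end usage; relies on defaults for all other
    # QuantizationArguments fields, folder paths are illustrative.
    quantization_args = QuantizationArguments(
        op_block_list=["LayerNormalization", "Gather"]
    )
    quantize("models/my-model", "models/my-model-quantized", quantization_args)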