from enum import Enum

from tqdm import tqdm
-from typing import Set, List
+from typing import Set, List, Optional
import onnx
import os

@@ -68,16 +68,6 @@ class QuantizationArguments:
        },
    )

-    op_block_list: List[str] = field(
-        default_factory=list,
-        metadata={
-            "help": """List of operators to exclude from quantization.
-            Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/)
-            or your custom implemented operators.""",
-            "nargs": "+",
-        },
-    )
-
    # 8-bit quantization
    per_channel: bool = field(
        default=None, metadata={"help": "Whether to quantize weights per channel"}
@@ -120,6 +110,16 @@ class QuantizationArguments:
        },
    )

+    op_block_list: List[str] = field(
+        default=None,
+        metadata={
+            "help": """List of operators to exclude from quantization.
+            Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/)
+            or your custom implemented operators.""",
+            "nargs": "+",
+        },
+    )
+

def get_operators(model: onnx.ModelProto) -> Set[str]:
    operators = set()
@@ -141,7 +141,7 @@ def quantize_q8(
    per_channel: bool,
    reduce_range: bool,
    weight_type: QuantType,
-    op_block_list: List[str] = [],
+    op_block_list: Optional[List[str]]
):
    """
    Quantize the weights of the model from float32 to int8/uint8
@@ -163,7 +163,7 @@ def quantize_q8(
        nodes_to_quantize=[],
        nodes_to_exclude=[],
        op_types_to_quantize=[
-            op for op in IntegerOpsRegistry.keys() if op not in op_block_list
+            op for op in IntegerOpsRegistry.keys() if op_block_list is None or op not in op_block_list
        ],
        extra_options=dict(
            EnableSubgraph=True,
@@ -178,7 +178,7 @@ def quantize_q8(
def quantize_fp16(
    model: onnx.ModelProto,
    save_path: str,
-    op_block_list: List[str] = [],
+    op_block_list: Optional[List[str]]
):
    """
    Quantize the weights of the model from float32 to float16
@@ -188,12 +188,19 @@ def quantize_fp16(
    # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2338583841
    disable_shape_infer = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF

+    convert_kwargs = {}
+
+    # Only include the 'op_block_list' keyword argument if a list is provided.
+    # This allows the library to apply its default behavior (see https://github.com/huggingface/transformers.js/pull/1036).
+    # Note: To set 'op_block_list' to an empty list (thereby overriding float16 defaults), a custom script is required.
+    if op_block_list is not None:
+        convert_kwargs["op_block_list"] = op_block_list
+
    model_fp16 = float16.convert_float_to_float16(
        model,
        keep_io_types=True,
        disable_shape_infer=disable_shape_infer,
-        op_block_list=op_block_list,
-
+        **convert_kwargs
    )
    graph = gs.import_onnx(model_fp16)
    graph.toposort()
@@ -229,7 +236,6 @@ def quantize_bnb4(
    save_path: str,
    block_size: int,
    quant_type: int,
-    op_block_list: List[str] = [],
):
    """
    Quantize the weights of the model from float32 to 4-bit int using MatMulBnb4Quantizer
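
Below is a minimal usage sketch (not part of the commit) showing how the updated quantize_fp16 helper behaves with the new optional argument; the model filenames and the "Conv"/"Resize" operator names are illustrative only:

    import onnx

    # Passing None omits the keyword entirely, so
    # float16.convert_float_to_float16 falls back to its built-in block list.
    quantize_fp16(onnx.load("model.onnx"), "model_fp16.onnx", None)

    # Passing an explicit list forwards it, excluding only these ops from fp16 conversion.
    quantize_fp16(onnx.load("model.onnx"), "model_fp16_custom.onnx", ["Conv", "Resize"])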