Commit 2ce32be

Set default to none
1 parent ff81cb0 commit 2ce32be

1 file changed: +23 −17 lines changed

scripts/quantize.py

Lines changed: 23 additions & 17 deletions
@@ -1,7 +1,7 @@
 from enum import Enum
 
 from tqdm import tqdm
-from typing import Set, List
+from typing import Set, List, Optional
 import onnx
 import os
 
@@ -68,16 +68,6 @@ class QuantizationArguments:
         },
     )
 
-    op_block_list: List[str] = field(
-        default_factory=list,
-        metadata={
-            "help": """List of operators to exclude from quantization.
-            Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/)
-            or your custom implemented operators.""",
-            "nargs": "+",
-        },
-    )
-
     # 8-bit quantization
     per_channel: bool = field(
         default=None, metadata={"help": "Whether to quantize weights per channel"}
@@ -120,6 +110,16 @@ class QuantizationArguments:
         },
     )
 
+    op_block_list: List[str] = field(
+        default=None,
+        metadata={
+            "help": """List of operators to exclude from quantization.
+            Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/)
+            or your custom implemented operators.""",
+            "nargs": "+",
+        },
+    )
+
 
 def get_operators(model: onnx.ModelProto) -> Set[str]:
     operators = set()
@@ -141,7 +141,7 @@ def quantize_q8(
     per_channel: bool,
     reduce_range: bool,
     weight_type: QuantType,
-    op_block_list: List[str] = [],
+    op_block_list: Optional[List[str]]
 ):
     """
     Quantize the weights of the model from float32 to int8/uint8
@@ -163,7 +163,7 @@ def quantize_q8(
         nodes_to_quantize=[],
         nodes_to_exclude=[],
         op_types_to_quantize=[
-            op for op in IntegerOpsRegistry.keys() if op not in op_block_list
+            op for op in IntegerOpsRegistry.keys() if op_block_list is None or op not in op_block_list
         ],
         extra_options=dict(
             EnableSubgraph=True,
@@ -178,7 +178,7 @@ def quantize_q8(
 def quantize_fp16(
     model: onnx.ModelProto,
     save_path: str,
-    op_block_list: List[str] = [],
+    op_block_list: Optional[List[str]]
 ):
     """
     Quantize the weights of the model from float32 to float16
@@ -188,12 +188,19 @@ def quantize_fp16(
     # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2338583841
     disable_shape_infer = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF
 
+    convert_kwargs = {}
+
+    # Only include the 'op_block_list' keyword argument if a list is provided.
+    # This allows the library to apply its default behavior (see https://github.com/huggingface/transformers.js/pull/1036).
+    # Note: To set 'op_block_list' to an empty list (thereby overriding float16 defaults), a custom script is required.
+    if op_block_list is not None:
+        convert_kwargs["op_block_list"] = op_block_list
+
     model_fp16 = float16.convert_float_to_float16(
         model,
         keep_io_types=True,
         disable_shape_infer=disable_shape_infer,
-        op_block_list=op_block_list,
-
+        **convert_kwargs
     )
     graph = gs.import_onnx(model_fp16)
     graph.toposort()
@@ -229,7 +236,6 @@ def quantize_bnb4(
     save_path: str,
     block_size: int,
     quant_type: int,
-    op_block_list: List[str] = [],
 ):
     """
     Quantize the weights of the model from float32 to 4-bit int using MatMulBnb4Quantizer
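The net effect of the change: when `op_block_list` stays at its new default of `None`, the `op_block_list` keyword is not forwarded to `float16.convert_float_to_float16` at all, so the converter falls back to its own built-in block list; passing an explicit list (or an empty one via a custom script) overrides that. Below is a minimal sketch of the forwarding pattern; `LIBRARY_DEFAULT_BLOCK_LIST`, `convert`, and `build_kwargs` are hypothetical stand-ins for illustration, not part of onnxconverter-common or this script.

```python
from typing import Dict, List, Optional

# Hypothetical stand-in for the float16 converter's built-in defaults;
# the real defaults live inside the conversion library, not here.
LIBRARY_DEFAULT_BLOCK_LIST = ["SimplifiedLayerNormalization", "Pow"]  # illustrative values only

def convert(op_block_list: Optional[List[str]] = None) -> List[str]:
    # Mimics a library function that falls back to its own defaults
    # when the caller does not pass op_block_list at all.
    return LIBRARY_DEFAULT_BLOCK_LIST if op_block_list is None else op_block_list

def build_kwargs(op_block_list: Optional[List[str]]) -> Dict[str, List[str]]:
    # Same pattern as the commit: forward the keyword only when a list was given.
    convert_kwargs = {}
    if op_block_list is not None:
        convert_kwargs["op_block_list"] = op_block_list
    return convert_kwargs

print(convert(**build_kwargs(None)))        # library defaults apply
print(convert(**build_kwargs(["MatMul"])))  # only the caller's operators are blocked
print(convert(**build_kwargs([])))          # empty list overrides the defaults entirely
```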
