Commit b31f40f

Add support for op_block_list
1 parent 705cfc4 commit b31f40f

2 files changed: +24 -2 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ __pycache__
 .vscode
 node_modules
 .cache
+.DS_STORE
 
 # Do not track build artifacts/generated files
 /dist

scripts/quantize.py

Lines changed: 23 additions & 2 deletions
@@ -1,7 +1,7 @@
 from enum import Enum
 
 from tqdm import tqdm
-from typing import Set
+from typing import Set, List
 import onnx
 import os
 
@@ -68,6 +68,16 @@ class QuantizationArguments:
         },
     )
 
+    op_block_list: List[str] = field(
+        default_factory=list,
+        metadata={
+            "help": """List of operators to exclude from quantization.
+            Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/)
+            or your custom implemented operators.""",
+            "nargs": "+",
+        },
+    )
+
     # 8-bit quantization
     per_channel: bool = field(
        default=None, metadata={"help": "Whether to quantize weights per channel"}
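
(Note, not part of the diff: the field's metadata keys, "help" and "nargs", follow the conventions of transformers.HfArgumentParser, which the script presumably uses to expose QuantizationArguments on the command line. Under that assumption, a minimal sketch of how the new flag would be parsed; the simplified dataclass and the operator names are illustrative only. With default_factory=list, omitting the flag leaves the block list empty.)

# Minimal sketch, assuming QuantizationArguments is parsed with
# transformers.HfArgumentParser (the parser itself is not shown in this diff).
from dataclasses import dataclass, field
from typing import List

from transformers import HfArgumentParser


@dataclass
class QuantizationArguments:
    op_block_list: List[str] = field(
        default_factory=list,
        metadata={"help": "Operators to exclude from quantization.", "nargs": "+"},
    )


parser = HfArgumentParser(QuantizationArguments)
# e.g. python scripts/quantize.py ... --op_block_list Conv LayerNormalization
(quant_args,) = parser.parse_args_into_dataclasses(
    args=["--op_block_list", "Conv", "LayerNormalization"]
)
print(quant_args.op_block_list)  # ['Conv', 'LayerNormalization']
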
@@ -131,6 +141,7 @@ def quantize_q8(
     per_channel: bool,
     reduce_range: bool,
     weight_type: QuantType,
+    op_block_list: List[str] = [],
 ):
     """
     Quantize the weights of the model from float32 to int8/uint8
@@ -151,7 +162,9 @@ def quantize_q8(
         tensors_range=None,
         nodes_to_quantize=[],
         nodes_to_exclude=[],
-        op_types_to_quantize=list(IntegerOpsRegistry.keys()),
+        op_types_to_quantize=[
+            op for op in IntegerOpsRegistry.keys() if op not in op_block_list
+        ],
         extra_options=dict(
             EnableSubgraph=True,
             MatMulConstBOnly=True,
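
(Note, not part of the diff: the comprehension above keeps every op type the dynamic int8 quantizer is registered for, minus the blocked ones. A small standalone sketch of the same filtering; the import path for IntegerOpsRegistry is an assumption about where this file gets it, and the blocked operator is a hypothetical choice.)

# Sketch of the filtering used above; import path assumed to be
# onnxruntime.quantization.registry, where IntegerOpsRegistry maps op type
# names (e.g. "MatMul", "Gather") to their dynamic-quantization handlers.
from onnxruntime.quantization.registry import IntegerOpsRegistry

op_block_list = ["MatMul"]  # hypothetical: leave MatMul unquantized
op_types_to_quantize = [
    op for op in IntegerOpsRegistry.keys() if op not in op_block_list
]
print(sorted(op_types_to_quantize))  # every registered op type except MatMul
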
@@ -165,6 +178,7 @@ def quantize_q8(
 def quantize_fp16(
     model: onnx.ModelProto,
     save_path: str,
+    op_block_list: List[str] = [],
 ):
     """
     Quantize the weights of the model from float32 to float16
@@ -178,6 +192,8 @@ def quantize_fp16(
         model,
         keep_io_types=True,
         disable_shape_infer=disable_shape_infer,
+        op_block_list=op_block_list,
+
     )
     graph = gs.import_onnx(model_fp16)
     graph.toposort()
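
(Note, not part of the diff: the op_block_list keyword forwarded here matches the parameter of the same name in onnxconverter-common's float16 converter, which the surrounding keep_io_types and disable_shape_infer keywords also suggest; that converter leaves the listed op types in float32 rather than casting them. A hedged standalone sketch, with hypothetical paths and operator names:)

# Sketch, assuming the conversion above is onnxconverter_common.float16.
import onnx
from onnxconverter_common import float16

model = onnx.load("model.onnx")  # hypothetical input path
model_fp16 = float16.convert_float_to_float16(
    model,
    keep_io_types=True,                               # keep graph inputs/outputs in fp32
    op_block_list=["LayerNormalization", "Softmax"],  # hypothetical ops to keep in fp32
)
onnx.save(model_fp16, "model_fp16.onnx")
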
@@ -191,6 +207,7 @@ def quantize_q4(
     block_size: int,
     is_symmetric: bool,
     accuracy_level: int,
+    op_block_list: List[str] = [],
 ):
     """
     Quantize the weights of the model from float32 to 4-bit int
@@ -213,6 +230,7 @@ def quantize_bnb4(
     save_path: str,
     block_size: int,
     quant_type: int,
+    op_block_list: List[str] = [],
 ):
     """
     Quantize the weights of the model from float32 to 4-bit int using MatMulBnb4Quantizer
@@ -282,6 +300,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments
                 block_size=block_size,
                 is_symmetric=quantization_args.is_symmetric,
                 accuracy_level=quantization_args.accuracy_level,
+                op_block_list=quantization_args.op_block_list,
             )
             if mode == QuantMode.Q4F16:
                 quantize_fp16(
@@ -299,6 +318,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments
                     if quantization_args.quant_type is not None
                     else MatMulBnb4Quantizer.NF4
                 ),
+                op_block_list=quantization_args.op_block_list,
             )
 
         elif mode in (QuantMode.Q8, QuantMode.QI8, QuantMode.QU8):
@@ -331,6 +351,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments
                 per_channel=quantization_args.per_channel,
                 reduce_range=quantization_args.reduce_range,
                 weight_type=weight_type,
+                op_block_list=quantization_args.op_block_list,
             )
 