
Commit 3e6b01a

Minor code suggestions
1 parent 2ce32be commit 3e6b01a

File tree

1 file changed: +2 -28 lines changed


scripts/quantize.py

Lines changed: 2 additions & 28 deletions
@@ -1,7 +1,7 @@
 from enum import Enum

 from tqdm import tqdm
-from typing import Set, List, Optional
+from typing import Set
 import onnx
 import os

@@ -110,16 +110,6 @@ class QuantizationArguments:
         },
     )

-    op_block_list: List[str] = field(
-        default=None,
-        metadata={
-            "help": """List of operators to exclude from quantization.
-            Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/)
-            or your custom implemented operators.""",
-            "nargs": "+",
-        },
-    )
-

 def get_operators(model: onnx.ModelProto) -> Set[str]:
     operators = set()
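
For context on the removed field: a minimal sketch of how such a dataclass field maps to a command-line flag, assuming an argparse-backed dataclass parser such as transformers.HfArgumentParser that forwards the "nargs" metadata (the parser choice and the example flag values are illustrative, not taken from this commit):

# Illustrative sketch only; assumes transformers.HfArgumentParser semantics,
# which pass extra metadata keys such as "nargs" through to argparse.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Args:
    op_block_list: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "Operators to exclude from quantization.",
            "nargs": "+",
        },
    )

# With such a parser, the removed field would have accepted, e.g.:
#   python quantize.py --op_block_list LayerNormalization Conv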
@@ -141,7 +131,6 @@ def quantize_q8(
     per_channel: bool,
     reduce_range: bool,
     weight_type: QuantType,
-    op_block_list: Optional[List[str]]
 ):
     """
     Quantize the weights of the model from float32 to int8/uint8
@@ -162,9 +151,7 @@ def quantize_q8(
         tensors_range=None,
         nodes_to_quantize=[],
         nodes_to_exclude=[],
-        op_types_to_quantize=[
-            op for op in IntegerOpsRegistry.keys() if op_block_list is None or op not in op_block_list
-        ],
+        op_types_to_quantize=list(IntegerOpsRegistry.keys()),
         extra_options=dict(
             EnableSubgraph=True,
             MatMulConstBOnly=True,
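
The replacement line simply quantizes every op type registered for dynamic int8/uint8 quantization. A minimal sketch of what the expression evaluates to, and of how the removed filtering could still be reproduced outside this script (the block list below is hypothetical):

# Sketch: IntegerOpsRegistry maps op type -> quantizer class for dynamic
# quantization in onnxruntime; its keys are the quantizable op types.
from onnxruntime.quantization.registry import IntegerOpsRegistry

all_ops = list(IntegerOpsRegistry.keys())  # includes e.g. "MatMul", "Gather"

# Equivalent of the removed comprehension, with a hypothetical block list:
op_block_list = {"Gather"}
filtered_ops = [op for op in all_ops if op not in op_block_list]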
@@ -178,7 +165,6 @@ def quantize_q8(
 def quantize_fp16(
     model: onnx.ModelProto,
     save_path: str,
-    op_block_list: Optional[List[str]]
 ):
     """
     Quantize the weights of the model from float32 to float16
@@ -188,19 +174,10 @@ def quantize_fp16(
     # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2338583841
     disable_shape_infer = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF

-    convert_kwargs = {}
-
-    # Only include the 'op_block_list' keyword argument if a list is provided.
-    # This allows the library to apply its default behavior (see https://github.com/huggingface/transformers.js/pull/1036).
-    # Note: To set 'op_block_list' to an empty list (thereby overriding float16 defaults), a custom script is required.
-    if op_block_list is not None:
-        convert_kwargs["op_block_list"] = op_block_list
-
     model_fp16 = float16.convert_float_to_float16(
         model,
         keep_io_types=True,
         disable_shape_infer=disable_shape_infer,
-        **convert_kwargs
     )
     graph = gs.import_onnx(model_fp16)
     graph.toposort()
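
With convert_kwargs gone, float16 conversion always uses the library defaults: when op_block_list is not passed, onnxconverter-common falls back to its built-in DEFAULT_OP_BLOCK_LIST of float16-unsafe operators. A caller who still needs a custom exclusion list can invoke the converter directly; a minimal sketch, with hypothetical paths and op name:

# Sketch: calling onnxconverter-common directly instead of this script.
# Passing op_block_list overrides the library's DEFAULT_OP_BLOCK_LIST.
import onnx
from onnxconverter_common import float16

model = onnx.load("model.onnx")  # hypothetical input path
model_fp16 = float16.convert_float_to_float16(
    model,
    keep_io_types=True,
    op_block_list=["SimplifiedLayerNormalization"],  # hypothetical exclusion
)
onnx.save(model_fp16, "model_fp16.onnx")  # hypothetical output path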
@@ -294,7 +271,6 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArgumen
             quantize_fp16(
                 model,
                 save_path,
-                quantization_args.op_block_list
             )

         elif mode in (QuantMode.Q4, QuantMode.Q4F16):
@@ -311,7 +287,6 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArgumen
             quantize_fp16(
                 q4_model,
                 save_path,
-                quantization_args.op_block_list,
             )

         elif mode == QuantMode.BNB4:
@@ -356,7 +331,6 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArgumen
                 per_channel=quantization_args.per_channel,
                 reduce_range=quantization_args.reduce_range,
                 weight_type=weight_type,
-                op_block_list=quantization_args.op_block_list,
             )
