Skip to content

Commit aec1a03

Browse files
committed
Added flag '--op_types_to_exclude_fp16' to allow any op to be excluded from FP16 conversion
Signed-off-by: gcunhase <[email protected]>
1 parent 0853ce1 commit aec1a03

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

modelopt/onnx/quantization/__main__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,16 @@ def get_parser() -> argparse.ArgumentParser:
101101
nargs="+",
102102
help="A space-separated list of node types to exclude from quantization.",
103103
)
104+
argparser.add_argument(
105+
"--op_types_to_exclude_fp16",
106+
type=str,
107+
default=[],
108+
nargs="+",
109+
help=(
110+
"A space-separated list of node types to exclude from FP16 conversion. "
111+
"This is only relevant if '--high_precision_dtype != fp32'."
112+
),
113+
)
104114
argparser.add_argument(
105115
"--nodes_to_quantize",
106116
type=str,
@@ -274,6 +284,7 @@ def main():
274284
override_shapes=args.override_shapes,
275285
op_types_to_quantize=args.op_types_to_quantize,
276286
op_types_to_exclude=args.op_types_to_exclude,
287+
op_types_to_exclude_fp16=args.op_types_to_exclude_fp16,
277288
nodes_to_quantize=args.nodes_to_quantize,
278289
nodes_to_exclude=args.nodes_to_exclude,
279290
use_external_data_format=args.use_external_data_format,

modelopt/onnx/quantization/fp8.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def quantize(
168168
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
169169
op_types_to_quantize: list[str] | None = None,
170170
op_types_to_exclude: list[str] | None = None,
171+
op_types_to_exclude_fp16: list[str] = [],
171172
nodes_to_quantize: list[str] | None = None,
172173
nodes_to_exclude: list[str] | None = None,
173174
use_external_data_format: bool = False,
@@ -178,7 +179,6 @@ def quantize(
178179
passes: list[str] = ["concat_elimination"],
179180
log_level: str = "INFO",
180181
calibrate_per_node: bool = False,
181-
custom_ops_to_cast_fp32: list[str] = [],
182182
custom_ops_to_quantize: list[str] = [],
183183
direct_io_types: bool = False,
184184
**kwargs,
@@ -319,7 +319,7 @@ def quantize(
319319
onnx_model = convert_to_f16(
320320
onnx_model,
321321
keep_io_types=not direct_io_types,
322-
op_block_list=custom_ops_to_cast_fp32,
322+
op_block_list=op_types_to_exclude_fp16,
323323
low_precision_type=high_precision_dtype,
324324
trt_plugins=trt_extra_plugin_lib_paths,
325325
)

modelopt/onnx/quantization/int8.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def quantize(
119119
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
120120
op_types_to_quantize: list[str] | None = None,
121121
op_types_to_exclude: list[str] | None = None,
122+
op_types_to_exclude_fp16: list[str] = [],
122123
nodes_to_quantize: list[str] | None = None,
123124
nodes_to_exclude: list[str] | None = None,
124125
use_external_data_format: bool = False,
@@ -128,7 +129,6 @@ def quantize(
128129
passes: list[str] = ["concat_elimination"],
129130
log_level: str = "INFO",
130131
calibrate_per_node: bool = False,
131-
custom_ops_to_cast_fp32: list[str] = [],
132132
custom_ops_to_quantize: list[str] = [],
133133
direct_io_types: bool = False,
134134
**kwargs,
@@ -280,7 +280,7 @@ def quantize(
280280
onnx_model = convert_to_f16(
281281
onnx_model,
282282
keep_io_types=not direct_io_types,
283-
op_block_list=custom_ops_to_cast_fp32,
283+
op_block_list=op_types_to_exclude_fp16,
284284
low_precision_type=high_precision_dtype,
285285
trt_plugins=trt_extra_plugin_lib_paths,
286286
)

modelopt/onnx/quantization/quantize.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def quantize(
216216
override_shapes: str | None = None,
217217
op_types_to_quantize: list[str] | None = None,
218218
op_types_to_exclude: list[str] | None = None,
219+
op_types_to_exclude_fp16: list[str] | None = None,
219220
nodes_to_quantize: list[str] | None = None,
220221
nodes_to_exclude: list[str] | None = None,
221222
use_external_data_format: bool = False,
@@ -267,6 +268,9 @@ def quantize(
267268
This flag does not support regular expression.
268269
op_types_to_exclude:
269270
List of op types to exclude from quantization. This flag does not support regular expression.
271+
op_types_to_exclude_fp16:
272+
List of op types to exclude from FP16 conversion.
273+
This is only relevant if '--high_precision_dtype != fp32'.
270274
nodes_to_quantize:
271275
List of node names to quantize. If None (default), all supported nodes are quantized.
272276
This flag supports regular expression.
@@ -422,6 +426,8 @@ def quantize(
422426
quantize_mode,
423427
)
424428
trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins) # type: ignore[arg-type]
429+
op_types_to_exclude_fp16 = op_types_to_exclude_fp16 or []
430+
op_types_to_exclude_fp16.extend(list(custom_ops_to_cast_fp32.keys()))
425431

426432
# Use random scales if calibration data is not supplied
427433
if calibration_data is None:
@@ -464,6 +470,7 @@ def quantize(
464470
calibration_eps=calibration_eps,
465471
op_types_to_quantize=op_types_to_quantize,
466472
op_types_to_exclude=op_types_to_exclude,
473+
op_types_to_exclude_fp16=op_types_to_exclude_fp16,
467474
nodes_to_quantize=nodes_to_quantize,
468475
nodes_to_exclude=nodes_to_exclude,
469476
use_external_data_format=use_external_data_format,
@@ -474,7 +481,6 @@ def quantize(
474481
passes=passes,
475482
log_level=log_level,
476483
calibrate_per_node=calibrate_per_node,
477-
custom_ops_to_cast_fp32=list(custom_ops_to_cast_fp32.keys()),
478484
custom_ops_to_quantize=list(custom_ops_to_quantize.keys()),
479485
direct_io_types=direct_io_types,
480486
**kwargs,

0 commit comments

Comments (0)