Commit 0e13148

gcunhase authored and kevalmorabia97 committed
[5455919] Insert cast nodes for 'FP32 required' ops (#363)
Signed-off-by: gcunhase <[email protected]>
1 parent 5a66696 commit 0e13148

6 files changed (+58 / -16 lines)

CHANGELOG.rst

Lines changed: 9 additions & 0 deletions
@@ -1,6 +1,15 @@
 Model Optimizer Changelog (Linux)
 =================================
 
+0.39 (2025-10-xx)
+^^^^^^^^^^^^^^^^^
+
+**Deprecations**
+
+**New Features**
+
+- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
+
 0.37 (2025-09-xx)
 ^^^^^^^^^^^^^^^^^
 

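For context, here is a minimal sketch of how the new flag could be used from the Python API. It assumes the `modelopt.onnx.quantization.quantize` entry point whose signature is extended in this commit; the model path and op type names are illustrative:

```python
from modelopt.onnx.quantization import quantize

# Sketch: INT8 quantization with the rest of the graph converted to FP16,
# while keeping Resize and Celu nodes in FP32. "model.onnx" and the op
# names are illustrative, not taken from this commit.
quantize(
    onnx_path="model.onnx",
    quantize_mode="int8",
    high_precision_dtype="fp16",
    op_types_to_exclude_fp16=["Resize", "Celu"],
)
```
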
modelopt/onnx/quantization/__main__.py

Lines changed: 10 additions & 4 deletions
@@ -90,28 +90,33 @@ def get_parser() -> argparse.ArgumentParser:
     argparser.add_argument(
         "--op_types_to_quantize",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node types to quantize.",
     )
     argparser.add_argument(
         "--op_types_to_exclude",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node types to exclude from quantization.",
     )
+    argparser.add_argument(
+        "--op_types_to_exclude_fp16",
+        type=str,
+        nargs="+",
+        help=(
+            "A space-separated list of node types to exclude from FP16/BF16 conversion. "
+            "Relevant when --high_precision_dtype is 'fp16' or 'bf16'."
+        ),
+    )
     argparser.add_argument(
         "--nodes_to_quantize",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node names to quantize. Regular expressions are supported.",
     )
     argparser.add_argument(
         "--nodes_to_exclude",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node names to exclude from quantization. Regular expressions are supported.",
     )
@@ -273,6 +278,7 @@ def main():
         override_shapes=args.override_shapes,
         op_types_to_quantize=args.op_types_to_quantize,
         op_types_to_exclude=args.op_types_to_exclude,
+        op_types_to_exclude_fp16=args.op_types_to_exclude_fp16,
         nodes_to_quantize=args.nodes_to_quantize,
         nodes_to_exclude=args.nodes_to_exclude,
         use_external_data_format=args.use_external_data_format,

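A side note on the dropped `default=[]` lines: with `nargs="+"`, an omitted flag now parses to `None` rather than `[]`, which matches the `list[str] | None = None` signatures of the `quantize()` entry points below. A standalone sketch of that argparse behavior:

```python
import argparse

# With nargs="+" and no explicit default, an omitted flag yields None.
parser = argparse.ArgumentParser()
parser.add_argument("--op_types_to_exclude_fp16", type=str, nargs="+")

print(parser.parse_args([]))
# Namespace(op_types_to_exclude_fp16=None)
print(parser.parse_args(["--op_types_to_exclude_fp16", "Resize", "Celu"]))
# Namespace(op_types_to_exclude_fp16=['Resize', 'Celu'])
```
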
modelopt/onnx/quantization/fp8.py

Lines changed: 2 additions & 1 deletion
@@ -168,6 +168,7 @@ def quantize(
     calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -318,7 +319,7 @@ def quantize(
         onnx_model = convert_to_f16(
             onnx_model,
             keep_io_types=not direct_io_types,
-            op_block_list=["Resize"],
+            op_block_list=op_types_to_exclude_fp16 or [],
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
         )

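Note the behavioral change here: `Resize` is no longer blocked from FP16 conversion by default in the FP8 path. A caller relying on the old hardcoded block list would now opt in explicitly (a sketch, reusing the assumed entry point and illustrative path from above):

```python
from modelopt.onnx.quantization import quantize

# Sketch: reproduce the previously hardcoded op_block_list=["Resize"]
# by passing it explicitly through the new flag.
quantize(
    onnx_path="model.onnx",  # illustrative path
    quantize_mode="fp8",
    high_precision_dtype="fp16",
    op_types_to_exclude_fp16=["Resize"],
)
```
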
modelopt/onnx/quantization/int8.py

Lines changed: 2 additions & 0 deletions
@@ -119,6 +119,7 @@ def quantize(
     calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -279,6 +280,7 @@ def quantize(
         onnx_model = convert_to_f16(
             onnx_model,
             keep_io_types=not direct_io_types,
+            op_block_list=op_types_to_exclude_fp16 or [],
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
         )

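The `op_types_to_exclude_fp16 or []` idiom in both the FP8 and INT8 paths normalizes the `None` default (see the argparse note above) to the empty block list that `convert_to_f16` expects. A tiny self-contained sketch of the normalization:

```python
def normalize_block_list(op_types_to_exclude_fp16: list[str] | None) -> list[str]:
    # Both None and [] collapse to an empty block list; a user-supplied
    # list passes through unchanged.
    return op_types_to_exclude_fp16 or []

assert normalize_block_list(None) == []
assert normalize_block_list([]) == []
assert normalize_block_list(["Resize"]) == ["Resize"]
```
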
modelopt/onnx/quantization/quantize.py

Lines changed: 21 additions & 3 deletions
@@ -77,7 +77,7 @@ def _preprocess_onnx(
     override_shapes: str,
     simplify: bool = False,
     quantize_mode: str = "int8",
-) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict]:
+) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
     logger.info(f"Preprocessing the model {onnx_path}")
     intermediate_generated_files = []
     output_dir = os.path.dirname(output_path)
@@ -176,13 +176,14 @@ def _preprocess_onnx(
         intermediate_generated_files.append(onnx_path)
 
     # If custom op precisions are given, add Cast or Q/DQ where appropriate.
+    custom_ops_to_cast = {}
     custom_ops_to_quantize = {}
     if trt_plugins_precision:
         custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
             onnx_model, trt_plugins_precision, quantize_mode
         )
-        if custom_ops_to_cast:
-            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast)
+        if custom_ops_to_cast.get("fp16", {}):
+            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast["fp16"])
             onnx_path = os.path.join(output_dir, f"{model_name}_castFP16.onnx")
             save_onnx(onnx_model, onnx_path, use_external_data_format)
             logger.info(f"Model is cloned to {onnx_path} after casting tensors to FP16")
@@ -195,6 +196,7 @@ def _preprocess_onnx(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast.get("fp32", {}),
         custom_ops_to_quantize,
     )
 
@@ -210,6 +212,7 @@ def quantize(
     override_shapes: str | None = None,
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -261,6 +264,9 @@ def quantize(
             This flag does not support regular expression.
         op_types_to_exclude:
             List of op types to exclude from quantization. This flag does not support regular expression.
+        op_types_to_exclude_fp16:
+            List of op types to exclude from FP16 conversion.
+            This is only relevant if '--high_precision_dtype != fp32'.
         nodes_to_quantize:
             List of node names to quantize. If None (default), all supported nodes are quantized.
             This flag supports regular expression.
@@ -402,6 +408,7 @@ def quantize(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast_fp32,
         custom_ops_to_quantize,
     ) = _preprocess_onnx(
         onnx_path,
@@ -416,6 +423,16 @@
     )
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]
 
+    # Update list with op types to exclude from FP16/BF16 conversion
+    op_types_to_exclude_fp16 = list(
+        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
+    )
+    if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
+        logger.warning(
+            "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
+            "Since the model won't be converted to a lower precision, this flag is void."
+        )
+
     # Use random scales if calibration data is not supplied
     if calibration_data is None:
         calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
@@ -457,6 +474,7 @@ def quantize(
         calibration_eps=calibration_eps,
         op_types_to_quantize=op_types_to_quantize,
         op_types_to_exclude=op_types_to_exclude,
+        op_types_to_exclude_fp16=op_types_to_exclude_fp16,
         nodes_to_quantize=nodes_to_quantize,
         nodes_to_exclude=nodes_to_exclude,
         use_external_data_format=use_external_data_format,

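The merge added in the `@@ -416,6 +423,16 @@` hunk uses `dict.fromkeys` for an order-preserving, duplicate-free union of the user-supplied list and the op types that `trt_plugins_precision` marked as FP32. A standalone sketch with illustrative plugin names:

```python
# User flag plus FP32 custom ops from trt_plugins_precision, deduplicated
# while preserving order ("CustomPluginA"/"CustomPluginB" are illustrative).
op_types_to_exclude_fp16 = ["Resize", "CustomPluginA"]
custom_ops_to_cast_fp32 = {
    "CustomPluginA": {"inp": [0], "out": [0]},
    "CustomPluginB": {"inp": [0], "out": [0]},
}

merged = list(
    dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
)
print(merged)  # ['Resize', 'CustomPluginA', 'CustomPluginB']
```
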
modelopt/onnx/trt_utils.py

Lines changed: 14 additions & 8 deletions
@@ -367,10 +367,12 @@ def interpret_trt_plugins_precision_flag(
         if trt_plugin_precision.count(":") == 1:
             if precision not in supported_precisions:
                 logger.warning(f"Precision {precision} is not supported. Skipping.")
-            if precision == "fp16":
-                custom_ops_to_cast[op_type] = {
-                    "inp": list(range(num_inps)),
-                    "out": list(range(num_outs)),
+            if precision in ["fp16", "fp32"]:
+                custom_ops_to_cast[precision] = {
+                    op_type: {
+                        "inp": list(range(num_inps)),
+                        "out": list(range(num_outs)),
+                    }
                 }
             if precision in ["int8", "fp8"]:
                 if precision != quantize_mode:
@@ -408,10 +410,14 @@ def interpret_trt_plugins_precision_flag(
                     f"Setting the custom op precision to be the same as quantize mode."
                 )
 
-            # Will cast the inputs to FP16 and the outputs back to FP32
-            inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == "fp16"]
-            out_precision_cast = [i for i, p in enumerate(out_precision) if p in ["fp16", "fp32"]]
-            custom_ops_to_cast[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}
+            # Will cast the inputs to FP16/FP32 and the outputs back to FP32
+            for precision in ["fp16", "fp32"]:
+                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
+                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
+                if inp_precision_cast:
+                    custom_ops_to_cast[precision] = {
+                        op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
+                    }
 
             # Will add Q/DQ nodes in the requested I/O indices
             inp_precision_quant = [i for i, p in enumerate(inp_precision) if p in ["int8", "fp8"]]

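After this change, `custom_ops_to_cast` is keyed by precision first and op type second, which lets the two groups be routed differently: the `"fp16"` group feeds `cast_custom_ops` during preprocessing, while the keys of the `"fp32"` group are folded into `op_types_to_exclude_fp16` (see `quantize.py` above). A sketch of the resulting shape, with illustrative plugin names:

```python
# precision -> op type -> I/O tensor indices to cast.
custom_ops_to_cast = {
    "fp16": {"MyPluginA": {"inp": [0, 1], "out": [0]}},
    "fp32": {"MyPluginB": {"inp": [0], "out": [0]}},
}

# The FP32 group contributes its op types to the FP16-conversion block list.
fp32_exclusions = list(custom_ops_to_cast.get("fp32", {}).keys())
print(fp32_exclusions)  # ['MyPluginB']
```
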