
Commit 59a2675

[5455919] Insert cast nodes for 'FP32 required' ops (#363)
Signed-off-by: gcunhase <[email protected]>
1 parent a041bbe commit 59a2675

6 files changed: +58 -16 lines changed

CHANGELOG.rst

Lines changed: 9 additions & 0 deletions
@@ -1,6 +1,15 @@
 Model Optimizer Changelog (Linux)
 =================================
 
+0.39 (2025-10-xx)
+^^^^^^^^^^^^^^^^^
+
+**Deprecations**
+
+**New Features**
+
+- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
+
 0.37 (2025-09-xx)
 ^^^^^^^^^^^^^^^^^
 
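A minimal usage sketch of the two options described in the changelog entry, assuming the `quantize` API exported from `modelopt.onnx.quantization` mirrors the parameters added in this commit; the model path, op types, and plugin name are illustrative:

    from modelopt.onnx.quantization import quantize

    # Option 1: keep selected op types in FP32 while the rest of the model is
    # converted to FP16 after quantization (op types here are illustrative).
    quantize(
        onnx_path="model.onnx",
        quantize_mode="int8",
        high_precision_dtype="fp16",
        op_types_to_exclude_fp16=["LayerNormalization", "Softmax"],
    )

    # Option 2: for a custom TensorRT op, request FP32 precision directly.
    # Format "<op_type>:<precision>", as parsed by interpret_trt_plugins_precision_flag.
    quantize(
        onnx_path="model.onnx",
        quantize_mode="int8",
        high_precision_dtype="fp16",
        trt_plugins_precision=["MyCustomPlugin:fp32"],
    )
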
modelopt/onnx/quantization/__main__.py

Lines changed: 10 additions & 4 deletions
@@ -90,28 +90,33 @@ def get_parser() -> argparse.ArgumentParser:
     argparser.add_argument(
         "--op_types_to_quantize",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node types to quantize.",
     )
     argparser.add_argument(
         "--op_types_to_exclude",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node types to exclude from quantization.",
     )
+    argparser.add_argument(
+        "--op_types_to_exclude_fp16",
+        type=str,
+        nargs="+",
+        help=(
+            "A space-separated list of node types to exclude from FP16/BF16 conversion. "
+            "Relevant when --high_precision_dtype is 'fp16' or 'bf16'."
+        ),
+    )
     argparser.add_argument(
         "--nodes_to_quantize",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node names to quantize. Regular expressions are supported.",
     )
     argparser.add_argument(
         "--nodes_to_exclude",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node names to exclude from quantization. Regular expressions are supported.",
     )
@@ -274,6 +279,7 @@ def main():
         override_shapes=args.override_shapes,
         op_types_to_quantize=args.op_types_to_quantize,
         op_types_to_exclude=args.op_types_to_exclude,
+        op_types_to_exclude_fp16=args.op_types_to_exclude_fp16,
         nodes_to_quantize=args.nodes_to_quantize,
         nodes_to_exclude=args.nodes_to_exclude,
         use_external_data_format=args.use_external_data_format,
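
Aside from the new flag, this hunk also drops `default=[]` from the list-valued flags. With `nargs="+"` and no explicit default, argparse yields `None` when a flag is omitted, which matches the `list[str] | None = None` signatures downstream. A quick standalone sketch of that behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--op_types_to_exclude_fp16", type=str, nargs="+")

    # Flag omitted: argparse defaults to None, not [].
    assert parser.parse_args([]).op_types_to_exclude_fp16 is None

    # Flag given: values arrive as a list of strings.
    args = parser.parse_args(["--op_types_to_exclude_fp16", "Resize", "Softmax"])
    assert args.op_types_to_exclude_fp16 == ["Resize", "Softmax"]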

modelopt/onnx/quantization/fp8.py

Lines changed: 2 additions & 1 deletion
@@ -168,6 +168,7 @@ def quantize(
     calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -318,7 +319,7 @@ def quantize(
         onnx_model = convert_to_f16(
             onnx_model,
             keep_io_types=not direct_io_types,
-            op_block_list=["Resize"],
+            op_block_list=op_types_to_exclude_fp16 or [],
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
         )
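
One behavioral consequence worth flagging: `Resize` was previously hard-blocked from FP16 here and is now excluded only on request. A sketch of how a caller restores the old behavior (argument names from this hunk; values illustrative, and the `convert_to_f16` import path is omitted since it is not shown in this diff):

    onnx_model = convert_to_f16(
        onnx_model,
        keep_io_types=True,            # i.e. direct_io_types=False
        op_block_list=["Resize"],      # previously implicit, now opt-in
        low_precision_type="fp16",     # high_precision_dtype: "fp16" or "bf16"
        trt_plugins=None,              # or trt_extra_plugin_lib_paths
    )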

modelopt/onnx/quantization/int8.py

Lines changed: 2 additions & 0 deletions
@@ -119,6 +119,7 @@ def quantize(
     calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -279,6 +280,7 @@ def quantize(
         onnx_model = convert_to_f16(
             onnx_model,
             keep_io_types=not direct_io_types,
+            op_block_list=op_types_to_exclude_fp16 or [],
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
         )
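
Note that, unlike the FP8 path above where the hunk replaces a hard-coded `op_block_list=["Resize"]`, the INT8 path previously passed no block list at all; after this change both paths honor the same user-controlled exclusion list.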

modelopt/onnx/quantization/quantize.py

Lines changed: 21 additions & 3 deletions
@@ -81,7 +81,7 @@ def _preprocess_onnx(
     override_shapes: str,
     simplify: bool = False,
     quantize_mode: str = "int8",
-) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict]:
+) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
     logger.info(f"Preprocessing the model {onnx_path}")
     intermediate_generated_files = []
     output_dir = os.path.dirname(output_path)
@@ -180,13 +180,14 @@ def _preprocess_onnx(
         intermediate_generated_files.append(onnx_path)
 
     # If custom op precisions are given, add Cast or Q/DQ where appropriate.
+    custom_ops_to_cast = {}
     custom_ops_to_quantize = {}
     if trt_plugins_precision:
         custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
             onnx_model, trt_plugins_precision, quantize_mode
         )
-        if custom_ops_to_cast:
-            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast)
+        if custom_ops_to_cast.get("fp16", {}):
+            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast["fp16"])
             onnx_path = os.path.join(output_dir, f"{model_name}_castFP16.onnx")
             save_onnx(onnx_model, onnx_path, use_external_data_format)
             logger.info(f"Model is cloned to {onnx_path} after casting tensors to FP16")
@@ -199,6 +200,7 @@ def _preprocess_onnx(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast.get("fp32", {}),
         custom_ops_to_quantize,
     )
 
@@ -214,6 +216,7 @@ def quantize(
     override_shapes: str | None = None,
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -265,6 +268,9 @@ def quantize(
            This flag does not support regular expression.
        op_types_to_exclude:
            List of op types to exclude from quantization. This flag does not support regular expression.
+        op_types_to_exclude_fp16:
+            List of op types to exclude from FP16 conversion.
+            This is only relevant if '--high_precision_dtype != fp32'.
        nodes_to_quantize:
            List of node names to quantize. If None (default), all supported nodes are quantized.
            This flag supports regular expression.
@@ -406,6 +412,7 @@ def quantize(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast_fp32,
         custom_ops_to_quantize,
     ) = _preprocess_onnx(
         onnx_path,
@@ -420,6 +427,16 @@ def quantize(
     )
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]
 
+    # Update list with op types to exclude from FP16/BF16 conversion
+    op_types_to_exclude_fp16 = list(
+        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
+    )
+    if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
+        logger.warning(
+            "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
+            "Since the model won't be converted to a lower precision, this flag is void."
+        )
+
     # Use random scales if calibration data is not supplied
     if calibration_data is None:
         calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
@@ -461,6 +478,7 @@ def quantize(
         calibration_eps=calibration_eps,
         op_types_to_quantize=op_types_to_quantize,
         op_types_to_exclude=op_types_to_exclude,
+        op_types_to_exclude_fp16=op_types_to_exclude_fp16,
         nodes_to_quantize=nodes_to_quantize,
         nodes_to_exclude=nodes_to_exclude,
         use_external_data_format=use_external_data_format,
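
The exclusion-list merge above uses `dict.fromkeys` as an order-preserving dedup: user-supplied op types come first, and each FP32-required custom op is appended once. A standalone sketch with illustrative values:

    op_types_to_exclude_fp16 = ["Resize", "Softmax"]  # from --op_types_to_exclude_fp16
    custom_ops_to_cast_fp32 = {"MyCustomPlugin": {"inp": [0], "out": [0]}}  # illustrative

    merged = list(
        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
    )
    assert merged == ["Resize", "Softmax", "MyCustomPlugin"]

    # Duplicates collapse while first-seen order is kept:
    assert list(dict.fromkeys(["Resize", "MyCustomPlugin", "Resize"])) == [
        "Resize",
        "MyCustomPlugin",
    ]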

modelopt/onnx/trt_utils.py

Lines changed: 14 additions & 8 deletions
@@ -367,10 +367,12 @@ def interpret_trt_plugins_precision_flag(
         if trt_plugin_precision.count(":") == 1:
             if precision not in supported_precisions:
                 logger.warning(f"Precision {precision} is not supported. Skipping.")
-            if precision == "fp16":
-                custom_ops_to_cast[op_type] = {
-                    "inp": list(range(num_inps)),
-                    "out": list(range(num_outs)),
+            if precision in ["fp16", "fp32"]:
+                custom_ops_to_cast[precision] = {
+                    op_type: {
+                        "inp": list(range(num_inps)),
+                        "out": list(range(num_outs)),
+                    }
                 }
             if precision in ["int8", "fp8"]:
                 if precision != quantize_mode:
@@ -408,10 +410,14 @@ def interpret_trt_plugins_precision_flag(
                     f"Setting the custom op precision to be the same as quantize mode."
                 )
 
-            # Will cast the inputs to FP16 and the outputs back to FP32
-            inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == "fp16"]
-            out_precision_cast = [i for i, p in enumerate(out_precision) if p in ["fp16", "fp32"]]
-            custom_ops_to_cast[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}
+            # Will cast the inputs to FP16/FP32 and the outputs back to FP32
+            for precision in ["fp16", "fp32"]:
+                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
+                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
+                if inp_precision_cast:
+                    custom_ops_to_cast[precision] = {
+                        op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
+                    }
 
             # Will add Q/DQ nodes in the requested I/O indices
             inp_precision_quant = [i for i, p in enumerate(inp_precision) if p in ["int8", "fp8"]]
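
After this change, `custom_ops_to_cast` is keyed by precision first and op type second, so the FP16 group (cast during preprocessing) and the FP32 group (merged into the FP16-conversion block list) can be consumed independently. A sketch of the resulting shape for a hypothetical two-input, one-output plugin requested as `MyCustomPlugin:fp32`:

    custom_ops_to_cast = {
        "fp32": {
            "MyCustomPlugin": {"inp": [0, 1], "out": [0]},
        },
    }

    # How quantize.py consumes the two groups in this commit:
    fp16_casts = custom_ops_to_cast.get("fp16", {})  # Cast nodes inserted during preprocessing
    fp32_ops = custom_ops_to_cast.get("fp32", {})    # op types folded into op_types_to_exclude_fp16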
