9 changes: 9 additions & 0 deletions CHANGELOG.rst
@@ -1,6 +1,15 @@
Model Optimizer Changelog (Linux)
=================================

0.39 (2025-10-xx)
^^^^^^^^^^^^^^^^^

**Deprecations**

**New Features**

- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, the same exclusion can be achieved by specifying ``'fp32'`` precision in ``trt_plugins_precision``.

0.37 (2025-09-xx)
^^^^^^^^^^^^^^^^^

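To make the new option concrete, here is a minimal usage sketch through the Python entry point touched by this PR. The import path follows the repository layout; the model path is hypothetical, and ``quantize_mode``/``high_precision_dtype`` are assumed to be sibling parameters of the same ``quantize()`` function shown in the quantize.py diff below.

from modelopt.onnx.quantization.quantize import quantize

# Quantize to INT8, convert the rest of the graph to FP16, but keep Resize ops in FP32.
quantize(
    onnx_path="model.onnx",               # hypothetical input model
    quantize_mode="int8",                 # assumed parameter name
    high_precision_dtype="fp16",          # assumed parameter name; "fp32" makes the flag below a no-op
    op_types_to_exclude_fp16=["Resize"],  # new option added in this PR
)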
14 changes: 10 additions & 4 deletions modelopt/onnx/quantization/__main__.py
@@ -90,28 +90,33 @@ def get_parser() -> argparse.ArgumentParser:
argparser.add_argument(
"--op_types_to_quantize",
type=str,
default=[],
nargs="+",
help="A space-separated list of node types to quantize.",
)
argparser.add_argument(
"--op_types_to_exclude",
type=str,
default=[],
nargs="+",
help="A space-separated list of node types to exclude from quantization.",
)
argparser.add_argument(
"--op_types_to_exclude_fp16",
type=str,
nargs="+",
help=(
"A space-separated list of node types to exclude from FP16/BF16 conversion. "
"Relevant when --high_precision_dtype is 'fp16' or 'bf16'."
),
)
argparser.add_argument(
"--nodes_to_quantize",
type=str,
default=[],
nargs="+",
help="A space-separated list of node names to quantize. Regular expressions are supported.",
)
argparser.add_argument(
"--nodes_to_exclude",
type=str,
default=[],
nargs="+",
help="A space-separated list of node names to exclude from quantization. Regular expressions are supported.",
)
@@ -274,6 +279,7 @@ def main():
override_shapes=args.override_shapes,
op_types_to_quantize=args.op_types_to_quantize,
op_types_to_exclude=args.op_types_to_exclude,
op_types_to_exclude_fp16=args.op_types_to_exclude_fp16,
nodes_to_quantize=args.nodes_to_quantize,
nodes_to_exclude=args.nodes_to_exclude,
use_external_data_format=args.use_external_data_format,
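A quick self-contained illustration of how the space-separated list flags above parse; it reuses only the argparse wiring shown in this hunk:

import argparse

# Minimal reproduction of the new flag's argparse setup from __main__.py.
parser = argparse.ArgumentParser()
parser.add_argument("--op_types_to_exclude_fp16", type=str, nargs="+")

args = parser.parse_args(["--op_types_to_exclude_fp16", "Resize", "Softmax"])
print(args.op_types_to_exclude_fp16)  # ['Resize', 'Softmax']
# With the flag omitted the attribute is None (no default is set), which the
# quantize() call site normalizes via `op_types_to_exclude_fp16 or []`.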
3 changes: 2 additions & 1 deletion modelopt/onnx/quantization/fp8.py
@@ -168,6 +168,7 @@ def quantize(
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
op_types_to_quantize: list[str] | None = None,
op_types_to_exclude: list[str] | None = None,
op_types_to_exclude_fp16: list[str] | None = None,
nodes_to_quantize: list[str] | None = None,
nodes_to_exclude: list[str] | None = None,
use_external_data_format: bool = False,
@@ -318,7 +319,7 @@
onnx_model = convert_to_f16(
onnx_model,
keep_io_types=not direct_io_types,
op_block_list=["Resize"],
op_block_list=op_types_to_exclude_fp16 or [],
low_precision_type=high_precision_dtype,
trt_plugins=trt_extra_plugin_lib_paths,
)
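One behavioral note on this hunk: the previously hardcoded ``op_block_list=["Resize"]`` is replaced by the user-supplied list, so ``Resize`` is now converted to FP16 unless explicitly excluded. A sketch of the call shape follows; only the keyword arguments are taken from the diff, while the ``convert_to_f16`` import location and file names are assumptions.

import onnx

# Assumed import location for convert_to_f16; keyword arguments mirror the diff above.
from modelopt.onnx.quantization.qdq_utils import convert_to_f16

onnx_model = onnx.load("quantized.onnx")  # hypothetical quantized model
onnx_model = convert_to_f16(
    onnx_model,
    keep_io_types=True,         # graph inputs/outputs stay in FP32
    op_block_list=["Resize"],   # explicitly restores the old default behavior
    low_precision_type="fp16",  # or "bf16"
    trt_plugins=None,           # no custom TensorRT plugins in this sketch
)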
2 changes: 2 additions & 0 deletions modelopt/onnx/quantization/int8.py
@@ -119,6 +119,7 @@ def quantize(
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
op_types_to_quantize: list[str] | None = None,
op_types_to_exclude: list[str] | None = None,
op_types_to_exclude_fp16: list[str] | None = None,
nodes_to_quantize: list[str] | None = None,
nodes_to_exclude: list[str] | None = None,
use_external_data_format: bool = False,
@@ -279,6 +280,7 @@
onnx_model = convert_to_f16(
onnx_model,
keep_io_types=not direct_io_types,
op_block_list=op_types_to_exclude_fp16 or [],
low_precision_type=high_precision_dtype,
trt_plugins=trt_extra_plugin_lib_paths,
)
24 changes: 21 additions & 3 deletions modelopt/onnx/quantization/quantize.py
@@ -81,7 +81,7 @@ def _preprocess_onnx(
override_shapes: str,
simplify: bool = False,
quantize_mode: str = "int8",
) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict]:
) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
logger.info(f"Preprocessing the model {onnx_path}")
intermediate_generated_files = []
output_dir = os.path.dirname(output_path)
@@ -180,13 +180,14 @@ def _preprocess_onnx(
intermediate_generated_files.append(onnx_path)

# If custom op precisions are given, add Cast or Q/DQ where appropriate.
custom_ops_to_cast = {}
custom_ops_to_quantize = {}
if trt_plugins_precision:
custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
onnx_model, trt_plugins_precision, quantize_mode
)
if custom_ops_to_cast:
onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast)
if custom_ops_to_cast.get("fp16", {}):
onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast["fp16"])
onnx_path = os.path.join(output_dir, f"{model_name}_castFP16.onnx")
save_onnx(onnx_model, onnx_path, use_external_data_format)
logger.info(f"Model is cloned to {onnx_path} after casting tensors to FP16")
@@ -199,6 +200,7 @@
has_custom_op,
has_dds_op,
use_external_data_format,
custom_ops_to_cast.get("fp32", {}),
custom_ops_to_quantize,
)
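The restructured ``custom_ops_to_cast`` map is now keyed by target precision first, then by op type. A hypothetical example of its shape after this change (op names invented for illustration):

# Hypothetical contents after interpret_trt_plugins_precision_flag, given one custom
# op requested at FP16 and another at FP32 (op names are invented):
custom_ops_to_cast = {
    "fp16": {"MyFp16Plugin": {"inp": [0, 1], "out": [0]}},
    "fp32": {"MyFp32Plugin": {"inp": [0], "out": [0]}},
}
# _preprocess_onnx casts the "fp16" group in place and returns the "fp32" group via
# custom_ops_to_cast.get("fp32", {}), which quantize() folds into op_types_to_exclude_fp16.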

@@ -214,6 +216,7 @@ def quantize(
override_shapes: str | None = None,
op_types_to_quantize: list[str] | None = None,
op_types_to_exclude: list[str] | None = None,
op_types_to_exclude_fp16: list[str] | None = None,
nodes_to_quantize: list[str] | None = None,
nodes_to_exclude: list[str] | None = None,
use_external_data_format: bool = False,
Expand Down Expand Up @@ -265,6 +268,9 @@ def quantize(
This flag does not support regular expression.
op_types_to_exclude:
List of op types to exclude from quantization. This flag does not support regular expression.
op_types_to_exclude_fp16:
List of op types to exclude from FP16/BF16 conversion.
Only relevant when 'high_precision_dtype' is set to 'fp16' or 'bf16'.
nodes_to_quantize:
List of node names to quantize. If None (default), all supported nodes are quantized.
This flag supports regular expression.
@@ -406,6 +412,7 @@ def quantize(
has_custom_op,
has_dds_op,
use_external_data_format,
custom_ops_to_cast_fp32,
custom_ops_to_quantize,
) = _preprocess_onnx(
onnx_path,
@@ -420,6 +427,16 @@
)
trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins) # type: ignore[arg-type]

# Merge user-specified op types to exclude from FP16/BF16 conversion with custom ops pinned to FP32
op_types_to_exclude_fp16 = list(
dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
)
if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
logger.warning(
"Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
"Since the model won't be converted to a lower precision, this flag is void."
)

# Use random scales if calibration data is not supplied
if calibration_data is None:
calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
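The ``list(dict.fromkeys(...))`` idiom above merges the user-supplied exclusions with the FP32-pinned custom ops while deduplicating and preserving order (dict keys keep insertion order in Python 3.7+). A standalone illustration with made-up op types:

user_excluded = ["Resize", "Softmax"]
custom_ops_to_cast_fp32 = {"Softmax": {"inp": [0], "out": [0]}, "MyPlugin": {"inp": [0], "out": []}}

merged = list(dict.fromkeys(user_excluded + list(custom_ops_to_cast_fp32.keys())))
print(merged)  # ['Resize', 'Softmax', 'MyPlugin'] -- duplicate 'Softmax' kept once, order preserved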
@@ -461,6 +478,7 @@
calibration_eps=calibration_eps,
op_types_to_quantize=op_types_to_quantize,
op_types_to_exclude=op_types_to_exclude,
op_types_to_exclude_fp16=op_types_to_exclude_fp16,
nodes_to_quantize=nodes_to_quantize,
nodes_to_exclude=nodes_to_exclude,
use_external_data_format=use_external_data_format,
22 changes: 14 additions & 8 deletions modelopt/onnx/trt_utils.py
@@ -367,10 +367,12 @@ def interpret_trt_plugins_precision_flag(
if trt_plugin_precision.count(":") == 1:
if precision not in supported_precisions:
logger.warning(f"Precision {precision} is not supported. Skipping.")
if precision == "fp16":
custom_ops_to_cast[op_type] = {
"inp": list(range(num_inps)),
"out": list(range(num_outs)),
if precision in ["fp16", "fp32"]:
custom_ops_to_cast[precision] = {
op_type: {
"inp": list(range(num_inps)),
"out": list(range(num_outs)),
}
}
if precision in ["int8", "fp8"]:
if precision != quantize_mode:
@@ -408,10 +410,14 @@
f"Setting the custom op precision to be the same as quantize mode."
)

# Will cast the inputs to FP16 and the outputs back to FP32
inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == "fp16"]
out_precision_cast = [i for i, p in enumerate(out_precision) if p in ["fp16", "fp32"]]
custom_ops_to_cast[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}
# Will cast the inputs to FP16/FP32 and the outputs back to FP32
for precision in ["fp16", "fp32"]:
inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
if inp_precision_cast:
custom_ops_to_cast[precision] = {
op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
}

Review comment on lines +413 to 421:
⚠️ Potential issue | 🔴 Critical

Bug: Drops output-only casting and overwrites maps.

  • Mapping is added only if inp_precision_cast is non-empty, so output-only casts are ignored.
  • Overwrites the existing precision map for each op instead of merging.
-            # Will cast the inputs to FP16/FP32 and the outputs back to FP32
-            for precision in ["fp16", "fp32"]:
-                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
-                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
-                if inp_precision_cast:
-                    custom_ops_to_cast[precision] = {
-                        op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
-                    }
+            # Will cast requested inputs to FP16/FP32 and outputs back to FP32
+            for precision in ["fp16", "fp32"]:
+                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
+                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
+                if inp_precision_cast or out_precision_cast:
+                    ops_map = custom_ops_to_cast.setdefault(precision, {})
+                    ops_map[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}

# Will add Q/DQ nodes in the requested I/O indices
inp_precision_quant = [i for i, p in enumerate(inp_precision) if p in ["int8", "fp8"]]
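A standalone repro of the issue flagged in the review above, with made-up precision lists: when an op requests an output-only cast (here, the FP32 group), the merged loop adds no entry at all.

# No modelopt imports needed; inp/out precision lists are made up for illustration.
inp_precision = ["fp16", "fp16"]  # both inputs requested at fp16
out_precision = ["fp32"]          # output requested at fp32 -> output-only for "fp32"

custom_ops_to_cast: dict = {}
op_type = "MyPlugin"
for precision in ["fp16", "fp32"]:
    inp_cast = [i for i, p in enumerate(inp_precision) if p == precision]
    out_cast = [i for i, p in enumerate(out_precision) if p == precision]
    if inp_cast:  # merged behavior: output-only requests fall through
        custom_ops_to_cast[precision] = {op_type: {"inp": inp_cast, "out": out_cast}}

print(custom_ops_to_cast)
# {'fp16': {'MyPlugin': {'inp': [0, 1], 'out': []}}} -- the fp32 output cast is lost.
# The suggested fix keys on `inp_cast or out_cast` and merges via setdefault so that
# multiple op types and output-only casts both survive.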