
Commit ae915ea

[OMNIML-2493] AutoCast: Configure target opset from CLI (#519)
## What does this PR do?

New feature.

**Overview:** AutoCast: Allow configuring target opset from CLI and main API.

## Usage

```bash
python -m modelopt.onnx.autocast --onnx_path model.onnx --opset 22
```

## Testing

```bash
pytest tests/unit/onnx/autocast/test_autocast.py -k test_opset
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No - minor API addition

## Additional Information

Signed-off-by: Gal Hubara Agam <[email protected]>
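For reference, here is the Python API counterpart of the CLI usage above. This is a minimal sketch based on the `convert_to_mixed_precision` signature touched in this commit; the file paths are illustrative:

```python
import onnx

from modelopt.onnx.autocast import convert_to_mixed_precision

# Mirrors `python -m modelopt.onnx.autocast --onnx_path model.onnx --opset 22`,
# but through the main API, using the new `opset` argument.
converted_model = convert_to_mixed_precision(
    onnx_path="model.onnx",  # illustrative input path
    low_precision_type="fp16",
    opset=22,  # target opset; may still be raised if certain ops require a higher version
)

# Save the converted model with the standard ONNX API.
onnx.save(converted_model, "model_fp16.onnx")
```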
1 parent 8cf516e commit ae915ea

File tree

5 files changed (+176 -4 lines changed):

- docs/source/guides/8_autocast.rst
- modelopt/onnx/autocast/__main__.py
- modelopt/onnx/autocast/convert.py
- modelopt/onnx/autocast/logging_config.py
- tests/unit/onnx/autocast/test_autocast.py

docs/source/guides/8_autocast.rst

Lines changed: 22 additions & 1 deletion

```diff
@@ -41,6 +41,7 @@ AutoCast can also be used programmatically through its Python API:
         providers=["cpu"], # list of Execution Providers for ONNX-Runtime backend
         trt_plugins=[], # list of TensorRT plugin library paths in .so format
         max_depth_of_reduction=None, # maximum depth of reduction allowed in low precision
+        opset=None, # optional target ONNX opset version (default: 13 for fp16, 22 for bf16)
     )

     # Save the converted model
@@ -55,7 +56,7 @@ AutoCast follows these steps to convert a model:

    - Loads the ONNX model
    - Performs graph sanitization and optimizations
-   - Ensures minimum opset version requirements (22 for BF16, 13 for FP16)
+   - Ensures minimum opset version requirements (22 for BF16, 13 for FP16 by default, or user-specified via ``--opset``)

 #. **Node Classification**:

@@ -135,6 +136,14 @@ Best Practices
    - To also enable the CUDA execution provider, use ``--providers cpu cuda:x``, where ``x`` is your device ID (``x=0`` if your system only has 1 GPU).
    - Use ``--trt_plugins`` to provide the paths to the necessary TensorRT plugin libraries (in ``.so`` format).

+#. **Opset Version Control**
+
+   - Use ``--opset`` to specify a target ONNX opset version for the converted model.
+   - If not specified, AutoCast keeps the existing model's opset, subject to a minimum opset based on precision type (13 for FP16, 22 for BF16).
+   - A warning will be issued if you specify an opset lower than the recommended minimum.
+   - A warning will be issued if you specify an opset lower than the original model's opset, as downgrading opset versions may cause compatibility issues.
+   - The opset may be automatically increased beyond your specified value if certain operations require it (e.g., quantization nodes require opset >= 19).
+
 Limitations and Restrictions
 ----------------------------
 - AutoCast does not yet support quantized models.
@@ -176,3 +185,15 @@ Limit depth of reduction for precision-sensitive operations:
 .. code-block:: bash

    python -m modelopt.onnx.autocast --onnx_path model.onnx --max_depth_of_reduction 1024
+
+Specify a target opset version:
+
+.. code-block:: bash
+
+   python -m modelopt.onnx.autocast --onnx_path model.onnx --opset 19
+
+Convert to BF16 with a specific opset:
+
+.. code-block:: bash
+
+   python -m modelopt.onnx.autocast --onnx_path model.onnx --low_precision_type bf16 --opset 22
```
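As a Python-API companion to the CLI examples added above, here is a small sketch (not part of the commit) showing how to confirm the effective opset of the converted model; `onnx_utils.get_opset_version` is the same helper the new tests use for this check, and the model path is illustrative:

```python
import modelopt.onnx.utils as onnx_utils
from modelopt.onnx.autocast import convert_to_mixed_precision

# Request opset 19; AutoCast may raise it further if certain ops demand a higher version.
converted = convert_to_mixed_precision(
    onnx_path="model.onnx", low_precision_type="fp16", opset=19
)
print("effective opset:", onnx_utils.get_opset_version(converted))  # expected to be >= 19
```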

modelopt/onnx/autocast/__main__.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -175,6 +175,16 @@ def get_parser() -> argparse.ArgumentParser:
             "For example: op_type_1:fp16 op_type_2:[fp16,fp32]:[fp16]."
         ),
     )
+    parser.add_argument(
+        "--opset",
+        type=int,
+        help=(
+            "Target ONNX opset version for the converted model. If not specified, uses default minimum opset "
+            "based on precision type (22 for bf16, 13 for fp16). Note: BF16 requires opset >= 22 for full support. "
+            "The opset may be automatically increased if certain operations (e.g., quantization nodes) require a "
+            "higher version."
+        ),
+    )

     return parser

@@ -207,6 +217,7 @@ def main(argv=None):
         trt_plugins=args.trt_plugins,
         trt_plugins_precision=args.trt_plugins_precision,
         max_depth_of_reduction=args.max_depth_of_reduction,
+        opset=args.opset,
     )

     output_path = args.output_path
```

modelopt/onnx/autocast/convert.py

Lines changed: 29 additions & 1 deletion

```diff
@@ -60,6 +60,7 @@ def convert_to_mixed_precision(
     trt_plugins: list[str] = [],
     trt_plugins_precision: list[str] = [],
     max_depth_of_reduction: int | None = None,
+    opset: int | None = None,
 ) -> onnx.ModelProto:
     """Convert model to mixed precision.

@@ -81,6 +82,9 @@ def convert_to_mixed_precision(
         trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
         trt_plugins_precision: List indicating the precision for each custom op.
         max_depth_of_reduction: Maximum depth of reduction for node classification.
+        opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type
+            (22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations
+            require a higher version.

     Returns:
         onnx.ModelProto: The converted mixed precision model.
@@ -89,10 +93,34 @@ def convert_to_mixed_precision(
     model = onnx.load(onnx_path, load_external_data=True)
     assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16"

+    # Get original model's opset version
+    original_opset = onnx_utils.get_opset_version(model)
+
     # Apply graph sanitization and optimizations
     # Opsets < 22 have a very limited support for bfloat16
     # Otherwise, prefer to keep the original opset version unless it's very old
-    min_opset = 22 if low_precision_type == "bf16" else 13
+    if opset is not None:
+        min_opset = opset
+        # Validate opset compatibility
+        if low_precision_type == "bf16" and opset < 22:
+            logger.warning(
+                f"Opset {opset} has limited BF16 support. Recommended minimum opset is 22. "
+                "The conversion may fail or produce unexpected results."
+            )
+        elif low_precision_type == "fp16" and opset < 13:
+            logger.warning(
+                f"Opset {opset} has limited FP16 support. Recommended minimum opset is 13. "
+                "The conversion may fail or produce unexpected results."
+            )
+        # Warn if user-specified opset is lower than original
+        if opset < original_opset:
+            logger.warning(
+                f"Specified opset {opset} is lower than the original model's opset {original_opset}. "
+                "Downgrading opset version may cause compatibility issues or conversion failures."
+            )
+    else:
+        min_opset = 22 if low_precision_type == "bf16" else 13
+
     graph_sanitizer = GraphSanitizer(
         model,
         min_opset,
```
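To summarize the control flow added above, here is a condensed, standalone sketch of how the effective minimum opset is chosen. `choose_min_opset` is a hypothetical helper written only for illustration and is not part of the commit:

```python
import logging

logger = logging.getLogger(__name__)


def choose_min_opset(low_precision_type: str, opset: int | None, original_opset: int) -> int:
    """Illustrative reduction of the opset selection logic in convert_to_mixed_precision."""
    default_min = 22 if low_precision_type == "bf16" else 13  # precision-based minimum
    if opset is None:
        # No user override: fall back to the precision-based default.
        return default_min
    if opset < default_min:
        logger.warning("Opset %d has limited %s support.", opset, low_precision_type.upper())
    if opset < original_opset:
        logger.warning(
            "Specified opset %d is lower than the original model's opset %d.", opset, original_opset
        )
    # The user's value is kept; GraphSanitizer may still raise it later
    # (e.g., quantization nodes require opset >= 19).
    return opset
```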

modelopt/onnx/autocast/logging_config.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -65,8 +65,10 @@ def configure_logging(level=logging.INFO, log_file=None):
         except Exception as e:
             logger.error(f"Failed to setup file logging to {log_file}: {e!s}")

-    # Prevent log messages from propagating to the root logger
-    logger.propagate = False
+    # Allow log messages to propagate to the root logger for testing compatibility
+    # This enables pytest's caplog fixture to capture logs while still maintaining
+    # our custom formatting through the handlers above
+    logger.propagate = True

     # Ensure all child loggers inherit the level setting
     for name in logging.root.manager.loggerDict:
```
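The `logger.propagate = True` change above is what lets pytest's `caplog` fixture, which collects records through a handler installed on the root logger, see AutoCast's warnings in the new tests. A minimal sketch of that pattern follows; the logger name is assumed for illustration:

```python
import logging


def test_warning_reaches_caplog(caplog):
    # Assumed logger name for illustration; any child logger with propagate=True behaves the same.
    logger = logging.getLogger("modelopt.onnx.autocast")
    logger.propagate = True  # with propagate=False, caplog.text would stay empty

    with caplog.at_level(logging.WARNING):
        logger.warning("Opset 11 has limited FP16 support.")

    assert "limited FP16 support" in caplog.text
```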

tests/unit/onnx/autocast/test_autocast.py

Lines changed: 110 additions & 0 deletions

```diff
@@ -25,6 +25,7 @@
 import modelopt.onnx.autocast.utils as utils
 import modelopt.onnx.utils as onnx_utils
 from modelopt.onnx.autocast import convert_to_mixed_precision
+from modelopt.onnx.autocast.__main__ import get_parser, main
 from modelopt.onnx.autocast.logging_config import configure_logging

 configure_logging("DEBUG")
@@ -187,3 +188,112 @@ def test_conv_isinf_conversion(tmp_path, opset_version):
     opset_version = onnx_utils.get_opset_version(converted_model)
     supported_dtype = "float32" if opset_version < 20 else "float16"
     assert assert_input_precision(isinf_nodes, dtype=supported_dtype)
+
+
+@pytest.mark.parametrize("target_opset", [13, 17, 19, 21])
+def test_opset_parameter(temp_model_path, target_opset):
+    """Test that the opset parameter correctly sets the output model's opset version."""
+    # Convert with specific opset
+    converted_model = convert_to_mixed_precision(
+        onnx_path=temp_model_path, low_precision_type="fp16", opset=target_opset
+    )
+
+    # Verify the output model has the correct opset
+    output_opset = onnx_utils.get_opset_version(converted_model)
+    assert output_opset >= target_opset, f"Expected opset >= {target_opset}, got {output_opset}"
+
+    # Validate the model
+    onnx.checker.check_model(converted_model)
+
+
+def test_opset_fp16_warning(temp_model_path, caplog):
+    """Test that a warning is issued when using fp16 with opset < 13."""
+    # Convert with fp16 and very low opset
+    converted_model = convert_to_mixed_precision(
+        onnx_path=temp_model_path, low_precision_type="fp16", opset=11
+    )
+
+    # Check that a warning was logged
+    assert "limited FP16 support" in caplog.text, (
+        "Expected warning about FP16 support with low opset"
+    )
+    assert "Recommended minimum opset is 13" in caplog.text
+
+    # Model should still be created
+    assert isinstance(converted_model, onnx.ModelProto)
+
+
+def test_opset_bf16_warning(temp_model_path, caplog):
+    """Test that a warning is issued when using bf16 with opset < 22."""
+    # Convert with bf16 and low opset
+    converted_model = convert_to_mixed_precision(
+        onnx_path=temp_model_path, low_precision_type="bf16", opset=13
+    )
+
+    # Check that a warning was logged
+    assert "limited BF16 support" in caplog.text, (
+        "Expected warning about BF16 support with low opset"
+    )
+    assert "Recommended minimum opset is 22" in caplog.text
+
+    # Model should still be created
+    assert isinstance(converted_model, onnx.ModelProto)
+
+
+def test_opset_downgrade_warning(temp_model_path, caplog):
+    """Test that a warning is issued when specified opset is lower than original model's opset."""
+    # temp_model_path fixture creates a model with opset 20
+    # Convert with lower opset
+    converted_model = convert_to_mixed_precision(
+        onnx_path=temp_model_path, low_precision_type="fp16", opset=13
+    )
+
+    # Check that a warning was logged about downgrading
+    assert "lower than the original model's opset" in caplog.text, (
+        "Expected warning about downgrading opset"
+    )
+
+    # Model should still be created
+    assert isinstance(converted_model, onnx.ModelProto)
+
+
+def test_opset_cli_argument(temp_model_path, tmp_path):
+    """Test that the --opset CLI argument is properly parsed and used."""
+    # Test the CLI with opset argument
+    output_path = tmp_path / "test_output.onnx"
+    args = [
+        "--onnx_path",
+        temp_model_path,
+        "--output_path",
+        str(output_path),
+        "--opset",
+        "21",
+        "--low_precision_type",
+        "fp16",
+    ]
+
+    result_model = main(args)
+
+    # Verify the output model has the correct opset
+    output_opset = onnx_utils.get_opset_version(result_model)
+    assert output_opset >= 21, f"Expected opset >= 21, got {output_opset}"
+
+    # Verify the file was created
+    assert output_path.exists()
+
+    # Load and validate the saved model
+    saved_model = onnx.load(str(output_path))
+    onnx.checker.check_model(saved_model)
+
+
+def test_opset_parser_argument():
+    """Test that the parser correctly accepts the --opset argument."""
+    parser = get_parser()
+
+    # Test parsing with opset
+    args = parser.parse_args(["--onnx_path", "test.onnx", "--opset", "19"])
+    assert args.opset == 19
+
+    # Test parsing without opset (should be None)
+    args = parser.parse_args(["--onnx_path", "test.onnx"])
+    assert args.opset is None
```
