From 0744103af93b78ce857a1800ec3b5ebe9b2492fa Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 24 Oct 2025 14:51:03 -0400 Subject: [PATCH 01/13] Block IsInf and IsNaN from running in low precision Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 64e38f44e..b2218cedf 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -65,7 +65,13 @@ class InitializerConsumerTracker: ONNX_TYPES = [t.onnx_type for t in PRECISION_MAP.values()] -OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"] +OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = [ + "Upsample", + "NonMaxSuppression", + "Celu", + "IsInf", + "IsNaN", +] # Temporarily block these ops in low precision, as they are not supported yet OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop", "LSTM"]) From f92216c7496e7a48aabeede60e6d03f4ba9a247b Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 24 Oct 2025 14:51:30 -0400 Subject: [PATCH 02/13] Fix type inference error by postponing infer_shapes to PrecisionConverter.convert() Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/graphsanitizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py index d27379760..8e62bd8b6 100644 --- a/modelopt/onnx/autocast/graphsanitizer.py +++ b/modelopt/onnx/autocast/graphsanitizer.py @@ -92,7 +92,6 @@ def convert_fp64_to_fp32(self) -> None: if modified: logger.info("Converted FP64 initializers, I/O types, and nodes to FP32") - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True) def 
ensure_custom_ops_precision(self) -> None: """Ensure that custom ops run in the requested precision.""" From d4296815e88ecbe061ff2607322d337dc17e3eb0 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 24 Oct 2025 16:35:04 -0400 Subject: [PATCH 03/13] Added unittest Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- .../onnx/quantization/lib_test_models.py | 99 +++++++++++++++++++ tests/unit/onnx/test_quantize_int8.py | 6 +- 2 files changed, 104 insertions(+), 1 deletion(-) diff --git a/tests/_test_utils/onnx/quantization/lib_test_models.py b/tests/_test_utils/onnx/quantization/lib_test_models.py index ffd827b62..b61061b27 100644 --- a/tests/_test_utils/onnx/quantization/lib_test_models.py +++ b/tests/_test_utils/onnx/quantization/lib_test_models.py @@ -822,3 +822,102 @@ def build_conv_act_pool_model(include_reshape_node=False): onnx.checker.check_model(model_inferred) return model_inferred + + +def build_conv_isinf_model(): + # Define your model inputs and outputs + input_names = ["input_0"] + output_names = ["output_0"] + input_shapes = [(6, 32, 900, 256)] + output_shapes = [(6, 32, 900, 256)] + + inputs = [ + helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape) + for input_name, input_shape in zip(input_names, input_shapes) + ] + outputs = [ + helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape) + for output_name, output_shape in zip(output_names, output_shapes) + ] + + # Create the ONNX graph with the nodes + nodes = [ + helper.make_node( + op_type="Conv", + inputs=["input_0", "weights_1"], + outputs=["conv1_conv/Conv2D:0"], + name="conv1_conv/Conv2D", + dilations=[1, 1], + group=1, + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[1, 1], + ), + helper.make_node( + op_type="Cast", + inputs=["conv1_conv/Conv2D:0"], + outputs=["cast1_cast/Cast:0"], + name="cast1_cast/Cast", + to=onnx.TensorProto.DOUBLE, + ), + helper.make_node( + 
op_type="IsInf", + inputs=["cast1_cast/Cast:0"], + outputs=["isinf1_isinf/IsInf:0"], + name="isinf1_isinf/IsInf", + ), + helper.make_node( + op_type="Greater", + inputs=["conv1_conv/Conv2D:0", "greater_const1"], + outputs=["greater1_greater/Greater:0"], + name="greater1_greater/Greater", + ), + helper.make_node( + op_type="And", + inputs=["isinf1_isinf/IsInf:0", "greater1_greater/Greater:0"], + outputs=["and1_and/And:0"], + name="and1_and/And", + ), + helper.make_node( + op_type="Where", + inputs=["and1_and/And:0", "conv1_conv/Conv2D:0", "where_const1"], + outputs=["output_0"], + name="where1_where/Where", + ), + ] + + # Create the ONNX initializers + initializers = [ + helper.make_tensor( + name="weights_1", + data_type=onnx.TensorProto.FLOAT, + dims=(32, 32, 3, 3), + vals=np.random.uniform(low=0.5, high=1.0, size=32 * 32 * 3 * 3), + ), + helper.make_tensor( + name="greater_const1", + data_type=onnx.TensorProto.FLOAT, + dims=(1,), + vals=[0], + ), + helper.make_tensor( + name="where_const1", + data_type=onnx.TensorProto.FLOAT, + dims=(1,), + vals=[10000], + ), + ] + + # Create the ONNX graph with the nodes and initializers + graph = helper.make_graph(nodes, "conv_isinf", inputs, outputs, initializer=initializers) + + # Create the ONNX model + model = helper.make_model(graph) + model.opset_import[0].version = 13 + model.ir_version = 10 + + # Check the ONNX model + model_inferred = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model_inferred) + + return model_inferred diff --git a/tests/unit/onnx/test_quantize_int8.py b/tests/unit/onnx/test_quantize_int8.py index 31c84eff1..9a8ca4992 100644 --- a/tests/unit/onnx/test_quantize_int8.py +++ b/tests/unit/onnx/test_quantize_int8.py @@ -19,7 +19,11 @@ import onnx_graphsurgeon as gs import pytest import torch -from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx +from _test_utils.onnx_quantization.lib_test_models import ( + SimpleMLP, + build_conv_isinf_model, + 
export_as_onnx, +) import modelopt.onnx.quantization as moq From 6ff6a3e999b2782ca334bb6222662a7bc926dfe3 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Tue, 28 Oct 2025 12:46:30 -0400 Subject: [PATCH 04/13] Add generic function to detect ops not supported in low precision and exclude from conversion Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/graphsanitizer.py | 3 - modelopt/onnx/autocast/precisionconverter.py | 13 ++--- modelopt/onnx/autocast/utils.py | 58 ++++++++++++++++++++ 3 files changed, 64 insertions(+), 10 deletions(-) diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py index 8e62bd8b6..48fc00d92 100644 --- a/modelopt/onnx/autocast/graphsanitizer.py +++ b/modelopt/onnx/autocast/graphsanitizer.py @@ -143,9 +143,6 @@ def remove_disconnected_outputs(self) -> None: def convert_opset(self) -> None: """Convert the model to the given opset version. - Args: - min_opset: minimum opset version to use - The method checks all opset imports and converts the model if any are below the minimum version. 
""" # Check all opset imports diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index b2218cedf..66d98724e 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -65,13 +65,7 @@ class InitializerConsumerTracker: ONNX_TYPES = [t.onnx_type for t in PRECISION_MAP.values()] -OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = [ - "Upsample", - "NonMaxSuppression", - "Celu", - "IsInf", - "IsNaN", -] +OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"] # Temporarily block these ops in low precision, as they are not supported yet OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop", "LSTM"]) @@ -144,6 +138,11 @@ def __init__( self.min_opset = min_opset self.max_ir_version = max_ir_version self.trt_plugins = trt_plugins + OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend( + utils.get_ops_without_low_precision_support( + self.model, self.low_precision_type.str_full, self.min_opset + ) + ) def convert( self, diff --git a/modelopt/onnx/autocast/utils.py b/modelopt/onnx/autocast/utils.py index 04cf4a2fe..a9fcf06b2 100644 --- a/modelopt/onnx/autocast/utils.py +++ b/modelopt/onnx/autocast/utils.py @@ -21,6 +21,9 @@ support the core functionality of model precision conversion. """ +import logging +from collections import defaultdict + import onnx @@ -115,3 +118,58 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int: if attr.name == "to": return attr.i raise ValueError("Cast node does not have 'to' attribute") + + +def get_ops_without_low_precision_support( + model: onnx.ModelProto, + low_precision_type: str, + min_opset: int, +) -> list[str]: + """Get a list of ops without low precision support for the current opset version. + + Args: + model: ONNX model. + low_precision_type: Target precision to reduce to ('float16' or 'bfloat16'). + min_opset: Minimum opset version. 
+ + Returns: + ops_without_support: List of ops without low precision support for the current opset version. + """ + # Obtain the current model's opset version + ai_onnx_domain = [ + opset + for opset in model.opset_import + if not opset.domain or opset.domain in ["ai.onnx", "ai.onnx.contrib", "trt.plugins"] + ] + opset_version = max(ai_onnx_domain[0].version, min_opset) + + # Get all ops precision support information + precision = "tensor(float16)" if low_precision_type == "float16" else "tensor(bfloat16)" + schemas_dict = defaultdict(dict) + for schema in onnx.defs.get_all_schemas_with_history(): + float16_supported = False + for constr in schema.type_constraints: + if precision in constr.allowed_type_strs: + float16_supported = True + break + schemas_dict[schema.name].update({schema.since_version: float16_supported}) + + # Check that all ops are supported in low precision for the current opset version. + # Otherwise, exclude from conversion. + ops_without_support = [] + for op, schema in schemas_dict.items(): + supported_opsets = [k for k, v in schema.items() if v] + if supported_opsets: + min_opset = min(supported_opsets) + if min_opset > opset_version: + ops_without_support.append(op) + else: + ops_without_support.append(op) + + if ops_without_support: + logging.warning( + f"{len(ops_without_support)} ops are not supported in {low_precision_type} in opset {opset_version}. " + f"Skipping those from conversion: {ops_without_support}." 
+ ) + + return ops_without_support From ce579787daa64db539e0618e20568b4034d9d528 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Tue, 28 Oct 2025 13:24:56 -0400 Subject: [PATCH 05/13] Consider only ops in the model and add min opset version in warning log Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/utils.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/modelopt/onnx/autocast/utils.py b/modelopt/onnx/autocast/utils.py index a9fcf06b2..571cfa737 100644 --- a/modelopt/onnx/autocast/utils.py +++ b/modelopt/onnx/autocast/utils.py @@ -145,8 +145,11 @@ def get_ops_without_low_precision_support( # Get all ops precision support information precision = "tensor(float16)" if low_precision_type == "float16" else "tensor(bfloat16)" + model_ops = {n.op_type for n in model.graph.node} schemas_dict = defaultdict(dict) for schema in onnx.defs.get_all_schemas_with_history(): + if schema.name not in model_ops: + continue float16_supported = False for constr in schema.type_constraints: if precision in constr.allowed_type_strs: @@ -156,20 +159,21 @@ def get_ops_without_low_precision_support( # Check that all ops are supported in low precision for the current opset version. # Otherwise, exclude from conversion. - ops_without_support = [] + ops_without_support = {} for op, schema in schemas_dict.items(): supported_opsets = [k for k, v in schema.items() if v] if supported_opsets: min_opset = min(supported_opsets) if min_opset > opset_version: - ops_without_support.append(op) + ops_without_support[op] = min_opset else: - ops_without_support.append(op) + ops_without_support[op] = None if ops_without_support: logging.warning( - f"{len(ops_without_support)} ops are not supported in {low_precision_type} in opset {opset_version}. " - f"Skipping those from conversion: {ops_without_support}." 
+ f"{len(ops_without_support)} ops are not supported in '{low_precision_type}' in opset {opset_version}, " + f"skipping those from conversion. Upgrade the model's opset version as follows to run them in low " + f" precision: {ops_without_support}." ) - return ops_without_support + return list(ops_without_support.keys()) From 0410a67920e909f53c44e6170f18b60267c59d7b Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Tue, 28 Oct 2025 14:21:36 -0400 Subject: [PATCH 06/13] Moved unsupported op detection out of constant Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 6 +++--- modelopt/onnx/autocast/utils.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 66d98724e..39b860b85 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -138,8 +138,8 @@ def __init__( self.min_opset = min_opset self.max_ir_version = max_ir_version self.trt_plugins = trt_plugins - OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend( - utils.get_ops_without_low_precision_support( + self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + ( + utils.get_op_types_not_supported_in_low_precision( self.model, self.low_precision_type.str_full, self.min_opset ) ) @@ -451,7 +451,7 @@ def _filter_unsupported_op_types( # precision so we need to set Resize and Upsample to high precision for node in self.model.graph.node: if ( - node.op_type in OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + node.op_type in self.op_types_not_supported_in_low_precision and node.name in low_precision_nodes ): low_precision_nodes.remove(node.name) diff --git a/modelopt/onnx/autocast/utils.py b/modelopt/onnx/autocast/utils.py index 571cfa737..ede2974f0 100644 --- a/modelopt/onnx/autocast/utils.py +++ 
b/modelopt/onnx/autocast/utils.py @@ -120,12 +120,12 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int: raise ValueError("Cast node does not have 'to' attribute") -def get_ops_without_low_precision_support( +def get_op_types_not_supported_in_low_precision( model: onnx.ModelProto, low_precision_type: str, min_opset: int, ) -> list[str]: - """Get a list of ops without low precision support for the current opset version. + """Get a list of ops not supported in low precision for the current opset version. Args: model: ONNX model. @@ -133,7 +133,7 @@ def get_ops_without_low_precision_support( min_opset: Minimum opset version. Returns: - ops_without_support: List of ops without low precision support for the current opset version. + ops_without_support: List of ops not supported in low precision for the current opset version. """ # Obtain the current model's opset version ai_onnx_domain = [ From 5c3f15ebe38d8fff80773e67a5f79a344c8c42ca Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:16:45 -0500 Subject: [PATCH 07/13] Move unittest to qdq_utils script Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/unit/onnx/test_qdq_rules_int8.py | 30 ++++++++++++++++++++++++++ tests/unit/onnx/test_quantize_int8.py | 6 +----- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/tests/unit/onnx/test_qdq_rules_int8.py b/tests/unit/onnx/test_qdq_rules_int8.py index 3f6104427..8e9ddfab7 100644 --- a/tests/unit/onnx/test_qdq_rules_int8.py +++ b/tests/unit/onnx/test_qdq_rules_int8.py @@ -22,6 +22,7 @@ from _test_utils.onnx.quantization.lib_test_models import ( build_conv_act_pool_model, build_conv_batchnorm_sig_mul_model, + build_conv_isinf_model, build_convtranspose_conv_residual_model, build_r1a_model, build_resnet_block, @@ -208,3 +209,32 @@ def test_conv_act_pool_int8(tmp_path, include_reshape_node): # Check that MaxPool is not quantized pool_nodes = [n for n in graph.nodes if n.op 
== "MaxPool"] assert assert_nodes_are_not_quantized(pool_nodes) + + +def test_conv_isinf_int8(tmp_path): + onnx_model = build_conv_isinf_model() + onnx_path = os.path.join(tmp_path, "conv_isinf_model.onnx") + save_onnx(onnx_model, onnx_path) + + quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") + + # Output model should be produced in the same tmp_path + output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx") + + # Check that quantized explicit model is generated + assert os.path.isfile(output_onnx_path) + + # Load the output model and check QDQ node placements + graph = gs.import_onnx(onnx.load(output_onnx_path)) + + # Check that Conv is quantized + conv_nodes = [n for n in graph.nodes if "Conv" in n.op] + assert assert_nodes_are_quantized(conv_nodes) + + # Check that IsInf is running in FP32 + isinf_nodes = [n for n in graph.nodes if n.op == "IsInf"] + for node in isinf_nodes: + for inp in node.inputs: + assert inp.dtype == "float32", ( + f"Node of type 'IsInf' has type {inp.dtype} but should have type float32" + ) diff --git a/tests/unit/onnx/test_quantize_int8.py b/tests/unit/onnx/test_quantize_int8.py index 9a8ca4992..c2c7cc5b8 100644 --- a/tests/unit/onnx/test_quantize_int8.py +++ b/tests/unit/onnx/test_quantize_int8.py @@ -19,11 +19,7 @@ import onnx_graphsurgeon as gs import pytest import torch -from _test_utils.onnx_quantization.lib_test_models import ( - SimpleMLP, - build_conv_isinf_model, - export_as_onnx, -) +from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx import modelopt.onnx.quantization as moq From 31fbcc57b0ebb0815af681240e8e10ab1f716221 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:17:40 -0500 Subject: [PATCH 08/13] Fix import path Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/unit/onnx/test_quantize_int8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/unit/onnx/test_quantize_int8.py b/tests/unit/onnx/test_quantize_int8.py index c2c7cc5b8..31c84eff1 100644 --- a/tests/unit/onnx/test_quantize_int8.py +++ b/tests/unit/onnx/test_quantize_int8.py @@ -19,7 +19,7 @@ import onnx_graphsurgeon as gs import pytest import torch -from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx +from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx import modelopt.onnx.quantization as moq From e326aff9e01e4a5bbe364bba97522165c79af678 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:22:25 -0500 Subject: [PATCH 09/13] Moved test model zoo to _test_utils.onnx Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/_test_utils/onnx/{quantization => }/lib_test_models.py | 0 tests/gpu/onnx/test_concat_elim.py | 2 +- tests/gpu/onnx/test_qdq_utils_fp8.py | 2 +- tests/gpu/onnx/test_quantize_fp8.py | 2 +- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py | 2 +- tests/gpu/onnx/test_simplify.py | 2 +- tests/unit/onnx/test_convtranspose_qdq.py | 2 +- tests/unit/onnx/test_partitioning.py | 2 +- tests/unit/onnx/test_qdq_rules_int8.py | 2 +- tests/unit/onnx/test_quantize_int8.py | 2 +- tests/unit/onnx/test_quantize_zint4.py | 2 +- 11 files changed, 10 insertions(+), 10 deletions(-) rename tests/_test_utils/onnx/{quantization => }/lib_test_models.py (100%) diff --git a/tests/_test_utils/onnx/quantization/lib_test_models.py b/tests/_test_utils/onnx/lib_test_models.py similarity index 100% rename from tests/_test_utils/onnx/quantization/lib_test_models.py rename to tests/_test_utils/onnx/lib_test_models.py diff --git a/tests/gpu/onnx/test_concat_elim.py b/tests/gpu/onnx/test_concat_elim.py index d6f8fe5ba..9b42c758e 100644 --- a/tests/gpu/onnx/test_concat_elim.py +++ b/tests/gpu/onnx/test_concat_elim.py @@ -18,7 +18,7 @@ import onnx import onnx_graphsurgeon as gs -from 
_test_utils.onnx.quantization.lib_test_models import build_conv_concat_model +from _test_utils.onnx.lib_test_models import build_conv_concat_model from modelopt.onnx.quantization.quantize import quantize diff --git a/tests/gpu/onnx/test_qdq_utils_fp8.py b/tests/gpu/onnx/test_qdq_utils_fp8.py index b6f45ac8c..0b99adf21 100644 --- a/tests/gpu/onnx/test_qdq_utils_fp8.py +++ b/tests/gpu/onnx/test_qdq_utils_fp8.py @@ -19,7 +19,7 @@ import onnx_graphsurgeon as gs import pytest import torch -from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx +from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx from modelopt.onnx.quantization.quantize import quantize diff --git a/tests/gpu/onnx/test_quantize_fp8.py b/tests/gpu/onnx/test_quantize_fp8.py index 6b76a2604..2e84082fe 100644 --- a/tests/gpu/onnx/test_quantize_fp8.py +++ b/tests/gpu/onnx/test_quantize_fp8.py @@ -18,7 +18,7 @@ import onnx import onnx_graphsurgeon as gs import torch -from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx +from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx import modelopt.onnx.quantization as moq diff --git a/tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py b/tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py index cb1cad932..73302df22 100644 --- a/tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py +++ b/tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py @@ -21,7 +21,7 @@ import torch from _test_utils.import_helper import skip_if_no_libcudnn -from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init +from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx, find_init from _test_utils.torch.quantization.quantize_common import get_awq_config import modelopt.onnx.quantization.int4 as int4 diff --git a/tests/gpu/onnx/test_simplify.py b/tests/gpu/onnx/test_simplify.py index 689de27c9..3b6acccb6 100644 --- a/tests/gpu/onnx/test_simplify.py +++ 
b/tests/gpu/onnx/test_simplify.py @@ -18,7 +18,7 @@ import onnx import onnx_graphsurgeon as gs import torch -from _test_utils.onnx.quantization.lib_test_models import NonSimplifiedModel, export_as_onnx +from _test_utils.onnx.lib_test_models import NonSimplifiedModel, export_as_onnx from _test_utils.onnx.quantization.utils import assert_nodes_are_quantized from modelopt.onnx.quantization.quantize import quantize diff --git a/tests/unit/onnx/test_convtranspose_qdq.py b/tests/unit/onnx/test_convtranspose_qdq.py index c14dc76e9..d06e8fc6d 100644 --- a/tests/unit/onnx/test_convtranspose_qdq.py +++ b/tests/unit/onnx/test_convtranspose_qdq.py @@ -17,7 +17,7 @@ import onnx import pytest import torch -from _test_utils.onnx.quantization.lib_test_models import UNet, export_as_onnx +from _test_utils.onnx.lib_test_models import UNet, export_as_onnx from modelopt.onnx.quantization import quantize diff --git a/tests/unit/onnx/test_partitioning.py b/tests/unit/onnx/test_partitioning.py index ff5b0d90e..43244846f 100644 --- a/tests/unit/onnx/test_partitioning.py +++ b/tests/unit/onnx/test_partitioning.py @@ -17,7 +17,7 @@ import onnx import onnx_graphsurgeon as gs -from _test_utils.onnx.quantization.lib_test_models import export_as_onnx +from _test_utils.onnx.lib_test_models import export_as_onnx from _test_utils.torch.vision_models import get_tiny_resnet_and_input from modelopt.onnx.quantization.graph_utils import ( diff --git a/tests/unit/onnx/test_qdq_rules_int8.py b/tests/unit/onnx/test_qdq_rules_int8.py index 8e9ddfab7..122bb2984 100644 --- a/tests/unit/onnx/test_qdq_rules_int8.py +++ b/tests/unit/onnx/test_qdq_rules_int8.py @@ -19,7 +19,7 @@ import onnx import onnx_graphsurgeon as gs import pytest -from _test_utils.onnx.quantization.lib_test_models import ( +from _test_utils.onnx.lib_test_models import ( build_conv_act_pool_model, build_conv_batchnorm_sig_mul_model, build_conv_isinf_model, diff --git a/tests/unit/onnx/test_quantize_int8.py 
b/tests/unit/onnx/test_quantize_int8.py index 31c84eff1..9dde2584d 100644 --- a/tests/unit/onnx/test_quantize_int8.py +++ b/tests/unit/onnx/test_quantize_int8.py @@ -19,7 +19,7 @@ import onnx_graphsurgeon as gs import pytest import torch -from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx +from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx import modelopt.onnx.quantization as moq diff --git a/tests/unit/onnx/test_quantize_zint4.py b/tests/unit/onnx/test_quantize_zint4.py index 1618a1a88..f60533a82 100644 --- a/tests/unit/onnx/test_quantize_zint4.py +++ b/tests/unit/onnx/test_quantize_zint4.py @@ -19,7 +19,7 @@ import numpy as np import onnx import onnx_graphsurgeon as gs -from _test_utils.onnx.quantization.lib_test_models import find_init +from _test_utils.onnx.lib_test_models import find_init import modelopt.onnx.quantization as moq from modelopt.onnx.quantization.int4 import quantize as quantize_int4 From 877487b469caf2b95f77845cf4ef58ead267b1fc Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:31:49 -0500 Subject: [PATCH 10/13] Create util to check the model's opset version Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/utils.py | 9 +++------ modelopt/onnx/quantization/quantize.py | 14 +++++++------- modelopt/onnx/utils.py | 10 ++++++++++ 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/modelopt/onnx/autocast/utils.py b/modelopt/onnx/autocast/utils.py index ede2974f0..b4ab775b0 100644 --- a/modelopt/onnx/autocast/utils.py +++ b/modelopt/onnx/autocast/utils.py @@ -26,6 +26,8 @@ import onnx +from modelopt.onnx.utils import get_opset_version + def setup_mappings(model: onnx.ModelProto) -> tuple[dict, dict, dict]: """Setup and return mappings for model components. 
@@ -136,12 +138,7 @@ def get_op_types_not_supported_in_low_precision( ops_without_support: List of ops not supported in low precision for the current opset version. """ # Obtain the current model's opset version - ai_onnx_domain = [ - opset - for opset in model.opset_import - if not opset.domain or opset.domain in ["ai.onnx", "ai.onnx.contrib", "trt.plugins"] - ] - opset_version = max(ai_onnx_domain[0].version, min_opset) + opset_version = max(get_opset_version(model), min_opset) # Get all ops precision support information precision = "tensor(float16)" if low_precision_type == "float16" else "tensor(bfloat16)" diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 9bc025e33..c00e3ca37 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -67,7 +67,12 @@ remove_input_dq_and_output_q, ) from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model -from modelopt.onnx.utils import duplicate_shared_constants, name_onnx_nodes, save_onnx +from modelopt.onnx.utils import ( + duplicate_shared_constants, + get_opset_version, + name_onnx_nodes, + save_onnx, +) __all__ = ["quantize"] @@ -113,12 +118,7 @@ def _preprocess_onnx( ) # Per-Channel support with QDQ format requires onnx opset version 13 or above - ai_onnx_domain = [ - opset - for opset in onnx_model.opset_import - if not opset.domain or opset.domain in ["ai.onnx", "ai.onnx.contrib"] - ] - opset_version = ai_onnx_domain[0].version + opset_version = get_opset_version(onnx_model) required_opset_version = 13 if opset_version < required_opset_version and opset_version != 1: diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 283b68ea6..a6b37758e 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -686,6 +686,16 @@ def update_domain(onnx_model: onnx.ModelProto, op_type: str, domain: str) -> onn return onnx_model +def get_opset_version(model: onnx.ModelProto) -> int: + 
"""Returns the opset version of the given model.""" + ai_onnx_domain = [ + opset + for opset in model.opset_import + if not opset.domain or opset.domain in ["ai.onnx", "ai.onnx.contrib", "trt.plugins"] + ] + return ai_onnx_domain[0].version + + def bfloat16_to_float32(bf16_array): """Converts a bfloat16 array (as raw data) to a float32 array.""" uint32_array = bf16_array.astype(np.uint32) << 16 From cd1bb3ebbdb884071379622d146667bc12f117d3 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:32:53 -0500 Subject: [PATCH 11/13] Updated test for multiple opset versions Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/_test_utils/onnx/lib_test_models.py | 4 ++-- tests/unit/onnx/test_qdq_rules_int8.py | 15 ++++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/_test_utils/onnx/lib_test_models.py b/tests/_test_utils/onnx/lib_test_models.py index b61061b27..50a991373 100644 --- a/tests/_test_utils/onnx/lib_test_models.py +++ b/tests/_test_utils/onnx/lib_test_models.py @@ -824,7 +824,7 @@ def build_conv_act_pool_model(include_reshape_node=False): return model_inferred -def build_conv_isinf_model(): +def build_conv_isinf_model(opset_version=13): # Define your model inputs and outputs input_names = ["input_0"] output_names = ["output_0"] @@ -913,7 +913,7 @@ def build_conv_isinf_model(): # Create the ONNX model model = helper.make_model(graph) - model.opset_import[0].version = 13 + model.opset_import[0].version = opset_version model.ir_version = 10 # Check the ONNX model diff --git a/tests/unit/onnx/test_qdq_rules_int8.py b/tests/unit/onnx/test_qdq_rules_int8.py index 122bb2984..43d6e4a4e 100644 --- a/tests/unit/onnx/test_qdq_rules_int8.py +++ b/tests/unit/onnx/test_qdq_rules_int8.py @@ -31,7 +31,7 @@ ) from modelopt.onnx.quantization.quantize import quantize -from modelopt.onnx.utils import save_onnx +from modelopt.onnx.utils import get_opset_version, save_onnx 
def assert_nodes_are_quantized(nodes): @@ -225,16 +225,21 @@ def test_conv_isinf_int8(tmp_path): assert os.path.isfile(output_onnx_path) # Load the output model and check QDQ node placements - graph = gs.import_onnx(onnx.load(output_onnx_path)) + onnx_model = onnx.load(output_onnx_path) + graph = gs.import_onnx(onnx_model) # Check that Conv is quantized conv_nodes = [n for n in graph.nodes if "Conv" in n.op] assert assert_nodes_are_quantized(conv_nodes) - # Check that IsInf is running in FP32 + # Check that IsInf is running in the lowest supported precision: + # - FP32 if opset < 20, or + # - FP16 if opset >= 20 isinf_nodes = [n for n in graph.nodes if n.op == "IsInf"] + opset_version = get_opset_version(onnx_model) + supported_dtype = "float32" if opset_version < 20 else "float16" for node in isinf_nodes: for inp in node.inputs: - assert inp.dtype == "float32", ( - f"Node of type 'IsInf' has type {inp.dtype} but should have type float32" + assert inp.dtype == supported_dtype, ( + f"Node of type {node.op} has type {inp.dtype} but should have type {supported_dtype}" ) From 8f79c72c67f075a2380d5ea7e39ae76ceb9c2ec1 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:33:16 -0500 Subject: [PATCH 12/13] Add unittest in autocast Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/unit/onnx/autocast/test_autocast.py | 41 +++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/unit/onnx/autocast/test_autocast.py b/tests/unit/onnx/autocast/test_autocast.py index d04ff9163..812df34e3 100644 --- a/tests/unit/onnx/autocast/test_autocast.py +++ b/tests/unit/onnx/autocast/test_autocast.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os from pathlib import Path import numpy as np import onnx +import onnx_graphsurgeon as gs import pytest +from _test_utils.onnx.lib_test_models import build_conv_isinf_model import modelopt.onnx.autocast.utils as utils import modelopt.onnx.utils as onnx_utils @@ -146,3 +149,41 @@ def test_convert_simple_model(temp_model_path, temp_output_path, keep_io_types): assert loaded_model.graph.output[0].type.tensor_type.elem_type == expected_io_type onnx.checker.check_model(loaded_model) + + +def assert_input_precision(nodes, dtype="float16"): + for node in nodes: + for inp in node.inputs: + assert inp.dtype == dtype, ( + f"Node of type {node.op} has type {inp.dtype} but should have type {dtype}" + ) + return True + + +@pytest.mark.parametrize("opset_version", [13, 21]) +def test_conv_isinf_conversion(tmp_path, opset_version): + onnx_model = build_conv_isinf_model(opset_version) + onnx_path = os.path.join(tmp_path, f"conv_isinf_model_opset{opset_version}.onnx") + onnx.save(onnx_model, onnx_path) + + # Convert the model + converted_model = convert_to_mixed_precision(onnx_path=onnx_path, keep_io_types=True) + + # Output model should be produced in the same tmp_path + output_onnx_path = onnx_path.replace(".onnx", ".fp16.onnx") + onnx.save(converted_model, output_onnx_path) + + # Import the converted model and check the input precision of its nodes + graph = gs.import_onnx(converted_model) + + # Check that Conv is converted + conv_nodes = [n for n in graph.nodes if "Conv" in n.op] + assert assert_input_precision(conv_nodes) + + # Check that IsInf is running in the lowest supported precision: + # - FP32 if opset < 20, or + # - FP16 if opset >= 20 + isinf_nodes = [n for n in graph.nodes if n.op == "IsInf"] + opset_version = onnx_utils.get_opset_version(converted_model) + supported_dtype = "float32" if opset_version < 20 else "float16" + assert assert_input_precision(isinf_nodes, dtype=supported_dtype) From 23f099488ea041eb3eb41e9427d710912323b9d0 Mon Sep 17 00:00:00 2001 From: gcunhase
<4861122+gcunhase@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:33:51 -0500 Subject: [PATCH 13/13] Add information in docstring, changed args order Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 6 +++++- modelopt/onnx/autocast/utils.py | 16 ++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 39b860b85..2a0e1e4a7 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -138,9 +138,13 @@ def __init__( self.min_opset = min_opset self.max_ir_version = max_ir_version self.trt_plugins = trt_plugins + + # Detect additional ops not supported in low precision according to the model's opset version self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + ( utils.get_op_types_not_supported_in_low_precision( - self.model, self.low_precision_type.str_full, self.min_opset + self.model, + self.min_opset, + self.low_precision_type.str_full, ) ) diff --git a/modelopt/onnx/autocast/utils.py b/modelopt/onnx/autocast/utils.py index b4ab775b0..d9dc3a1f1 100644 --- a/modelopt/onnx/autocast/utils.py +++ b/modelopt/onnx/autocast/utils.py @@ -124,15 +124,19 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int: def get_op_types_not_supported_in_low_precision( model: onnx.ModelProto, - low_precision_type: str, min_opset: int, + low_precision_type: str = "float16", ) -> list[str]: - """Get a list of ops not supported in low precision for the current opset version. + """Get a list of ops not supported in low precision for the opset_version = max(model.opset, min_opset). + + An op is considered to be supported if at least one of the inputs may be in low precision. + Ops where only some of the inputs may be in low precision are considered supported by this function + and may need special handling. 
See PrecisionConverter::_should_skip_low_precision_input_conversion. Args: model: ONNX model. - low_precision_type: Target precision to reduce to ('float16' or 'bfloat16'). min_opset: Minimum opset version. + low_precision_type: Target precision to reduce to ('float16' or 'bfloat16'). Returns: ops_without_support: List of ops not supported in low precision for the current opset version. @@ -160,9 +164,9 @@ def get_op_types_not_supported_in_low_precision( for op, schema in schemas_dict.items(): supported_opsets = [k for k, v in schema.items() if v] if supported_opsets: - min_opset = min(supported_opsets) - if min_opset > opset_version: - ops_without_support[op] = min_opset + min_supported_opset = min(supported_opsets) + if min_supported_opset > opset_version: + ops_without_support[op] = min_supported_opset else: ops_without_support[op] = None