4 changes: 0 additions & 4 deletions modelopt/onnx/autocast/graphsanitizer.py
@@ -92,7 +92,6 @@ def convert_fp64_to_fp32(self) -> None:

if modified:
logger.info("Converted FP64 initializers, I/O types, and nodes to FP32")
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)

def ensure_custom_ops_precision(self) -> None:
"""Ensure that custom ops run in the requested precision."""
@@ -144,9 +143,6 @@ def remove_disconnected_outputs(self) -> None:
def convert_opset(self) -> None:
"""Convert the model to the given opset version.

Args:
min_opset: minimum opset version to use

The method checks all opset imports and converts the model if any are below the minimum version.
"""
# Check all opset imports
11 changes: 10 additions & 1 deletion modelopt/onnx/autocast/precisionconverter.py
@@ -139,6 +139,15 @@ def __init__(
self.max_ir_version = max_ir_version
self.trt_plugins = trt_plugins

# Detect additional ops not supported in low precision according to the model's opset version
self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + (
utils.get_op_types_not_supported_in_low_precision(
self.model,
self.min_opset,
self.low_precision_type.str_full,
)
)

def convert(
self,
high_precision_nodes: list[str],
@@ -446,7 +455,7 @@ def _filter_unsupported_op_types(
# precision so we need to set Resize and Upsample to high precision
for node in self.model.graph.node:
if (
node.op_type in OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION
node.op_type in self.op_types_not_supported_in_low_precision
and node.name in low_precision_nodes
):
low_precision_nodes.remove(node.name)
63 changes: 63 additions & 0 deletions modelopt/onnx/autocast/utils.py
@@ -21,8 +21,13 @@
support the core functionality of model precision conversion.
"""

import logging
from collections import defaultdict

import onnx

from modelopt.onnx.utils import get_opset_version


def setup_mappings(model: onnx.ModelProto) -> tuple[dict, dict, dict]:
"""Setup and return mappings for model components.
@@ -115,3 +120,61 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int:
if attr.name == "to":
return attr.i
raise ValueError("Cast node does not have 'to' attribute")


def get_op_types_not_supported_in_low_precision(
model: onnx.ModelProto,
min_opset: int,
low_precision_type: str = "float16",
) -> list[str]:
"""Get a list of ops not supported in low precision for the opset_version = max(model.opset, min_opset).

    An op is considered supported if at least one of its inputs may be in low precision.
    Ops where only some inputs may be in low precision are still counted as supported here
    and may need special handling; see PrecisionConverter::_should_skip_low_precision_input_conversion.

Args:
model: ONNX model.
min_opset: Minimum opset version.
low_precision_type: Target precision to reduce to ('float16' or 'bfloat16').

Returns:
ops_without_support: List of ops not supported in low precision for the current opset version.
"""
# Obtain the current model's opset version
opset_version = max(get_opset_version(model), min_opset)

# Get all ops precision support information
precision = "tensor(float16)" if low_precision_type == "float16" else "tensor(bfloat16)"
model_ops = {n.op_type for n in model.graph.node}
schemas_dict = defaultdict(dict)
for schema in onnx.defs.get_all_schemas_with_history():
if schema.name not in model_ops:
continue
float16_supported = False
for constr in schema.type_constraints:
if precision in constr.allowed_type_strs:
float16_supported = True
break
schemas_dict[schema.name].update({schema.since_version: float16_supported})

# Check that all ops are supported in low precision for the current opset version.
# Otherwise, exclude from conversion.
ops_without_support = {}
for op, schema in schemas_dict.items():
supported_opsets = [k for k, v in schema.items() if v]
if supported_opsets:
min_supported_opset = min(supported_opsets)
if min_supported_opset > opset_version:
ops_without_support[op] = min_supported_opset
else:
ops_without_support[op] = None

if ops_without_support:
logging.warning(
f"{len(ops_without_support)} ops are not supported in '{low_precision_type}' in opset {opset_version}, "
f"skipping those from conversion. Upgrade the model's opset version as follows to run them in low "
f" precision: {ops_without_support}."
)

return list(ops_without_support.keys())
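A minimal usage sketch of the new helper (not part of this PR; the model path, minimum opset, and target opset below are assumptions). It reports which ops in the model lack low-precision support at the effective opset, and upgrading the opset via onnx.version_converter is one way to act on the warning:

import onnx
from onnx import version_converter

from modelopt.onnx.autocast.utils import get_op_types_not_supported_in_low_precision

model = onnx.load("model.onnx")  # hypothetical input model
unsupported = get_op_types_not_supported_in_low_precision(
    model, min_opset=13, low_precision_type="float16"
)
if unsupported:
    # One possible remedy suggested by the warning message: upgrade the opset so
    # the flagged ops (e.g. IsInf before opset 20) gain low-precision support.
    model = version_converter.convert_version(model, 21)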
14 changes: 7 additions & 7 deletions modelopt/onnx/quantization/quantize.py
@@ -67,7 +67,12 @@
remove_input_dq_and_output_q,
)
from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model
from modelopt.onnx.utils import duplicate_shared_constants, name_onnx_nodes, save_onnx
from modelopt.onnx.utils import (
duplicate_shared_constants,
get_opset_version,
name_onnx_nodes,
save_onnx,
)

__all__ = ["quantize"]

@@ -113,12 +118,7 @@ def _preprocess_onnx(
)

# Per-Channel support with QDQ format requires onnx opset version 13 or above
ai_onnx_domain = [
opset
for opset in onnx_model.opset_import
if not opset.domain or opset.domain in ["ai.onnx", "ai.onnx.contrib"]
]
opset_version = ai_onnx_domain[0].version
opset_version = get_opset_version(onnx_model)

required_opset_version = 13
if opset_version < required_opset_version and opset_version != 1:
10 changes: 10 additions & 0 deletions modelopt/onnx/utils.py
@@ -686,6 +686,16 @@ def update_domain(onnx_model: onnx.ModelProto, op_type: str, domain: str) -> onn
return onnx_model


def get_opset_version(model: onnx.ModelProto) -> int:
"""Returns the opset version of the given model."""
ai_onnx_domain = [
opset
for opset in model.opset_import
if not opset.domain or opset.domain in ["ai.onnx", "ai.onnx.contrib", "trt.plugins"]
]
return ai_onnx_domain[0].version


def bfloat16_to_float32(bf16_array):
"""Converts a bfloat16 array (as raw data) to a float32 array."""
uint32_array = bf16_array.astype(np.uint32) << 16
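As a quick illustration (assumed example, not from the PR): the new helper returns the version of the first opset import whose domain is empty, ai.onnx, ai.onnx.contrib, or trt.plugins, which is what _preprocess_onnx in quantize.py now relies on for the per-channel QDQ opset check:

import onnx
from modelopt.onnx.utils import get_opset_version

model = onnx.load("model.onnx")  # hypothetical model
opset_version = get_opset_version(model)
if opset_version < 13:
    print(f"Opset {opset_version} is too old for per-channel QDQ quantization")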
tests/_test_utils/onnx/lib_test_models.py
@@ -822,3 +822,102 @@ def build_conv_act_pool_model(include_reshape_node=False):
onnx.checker.check_model(model_inferred)

return model_inferred


def build_conv_isinf_model(opset_version=13):
# Define your model inputs and outputs
input_names = ["input_0"]
output_names = ["output_0"]
input_shapes = [(6, 32, 900, 256)]
output_shapes = [(6, 32, 900, 256)]

inputs = [
helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
for input_name, input_shape in zip(input_names, input_shapes)
]
outputs = [
helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
for output_name, output_shape in zip(output_names, output_shapes)
]

# Create the ONNX graph with the nodes
nodes = [
helper.make_node(
op_type="Conv",
inputs=["input_0", "weights_1"],
outputs=["conv1_conv/Conv2D:0"],
name="conv1_conv/Conv2D",
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
helper.make_node(
op_type="Cast",
inputs=["conv1_conv/Conv2D:0"],
outputs=["cast1_cast/Cast:0"],
name="cast1_cast/Cast",
to=onnx.TensorProto.DOUBLE,
),
helper.make_node(
op_type="IsInf",
inputs=["cast1_cast/Cast:0"],
outputs=["isinf1_isinf/IsInf:0"],
name="isinf1_isinf/IsInf",
),
helper.make_node(
op_type="Greater",
inputs=["conv1_conv/Conv2D:0", "greater_const1"],
outputs=["greater1_greater/Greater:0"],
name="greater1_greater/Greater",
),
helper.make_node(
op_type="And",
inputs=["isinf1_isinf/IsInf:0", "greater1_greater/Greater:0"],
outputs=["and1_and/And:0"],
name="and1_and/And",
),
helper.make_node(
op_type="Where",
inputs=["and1_and/And:0", "conv1_conv/Conv2D:0", "where_const1"],
outputs=["output_0"],
name="where1_where/Where",
),
]

# Create the ONNX initializers
initializers = [
helper.make_tensor(
name="weights_1",
data_type=onnx.TensorProto.FLOAT,
dims=(32, 32, 3, 3),
vals=np.random.uniform(low=0.5, high=1.0, size=32 * 32 * 3 * 3),
),
helper.make_tensor(
name="greater_const1",
data_type=onnx.TensorProto.FLOAT,
dims=(1,),
vals=[0],
),
helper.make_tensor(
name="where_const1",
data_type=onnx.TensorProto.FLOAT,
dims=(1,),
vals=[10000],
),
]

# Create the ONNX graph with the nodes and initializers
graph = helper.make_graph(nodes, "conv_isinf", inputs, outputs, initializer=initializers)

# Create the ONNX model
model = helper.make_model(graph)
model.opset_import[0].version = opset_version
model.ir_version = 10

# Check the ONNX model
model_inferred = onnx.shape_inference.infer_shapes(model)
onnx.checker.check_model(model_inferred)

return model_inferred
2 changes: 1 addition & 1 deletion tests/gpu/onnx/test_concat_elim.py
@@ -18,7 +18,7 @@

import onnx
import onnx_graphsurgeon as gs
from _test_utils.onnx.quantization.lib_test_models import build_conv_concat_model
from _test_utils.onnx.lib_test_models import build_conv_concat_model

from modelopt.onnx.quantization.quantize import quantize

2 changes: 1 addition & 1 deletion tests/gpu/onnx/test_qdq_utils_fp8.py
@@ -19,7 +19,7 @@
import onnx_graphsurgeon as gs
import pytest
import torch
from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx
from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx

from modelopt.onnx.quantization.quantize import quantize

2 changes: 1 addition & 1 deletion tests/gpu/onnx/test_quantize_fp8.py
@@ -18,7 +18,7 @@
import onnx
import onnx_graphsurgeon as gs
import torch
from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx
from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx

import modelopt.onnx.quantization as moq

2 changes: 1 addition & 1 deletion tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py
@@ -21,7 +21,7 @@

import torch
from _test_utils.import_helper import skip_if_no_libcudnn
from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init
from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx, find_init
from _test_utils.torch.quantization.quantize_common import get_awq_config

import modelopt.onnx.quantization.int4 as int4
2 changes: 1 addition & 1 deletion tests/gpu/onnx/test_simplify.py
@@ -18,7 +18,7 @@
import onnx
import onnx_graphsurgeon as gs
import torch
from _test_utils.onnx.quantization.lib_test_models import NonSimplifiedModel, export_as_onnx
from _test_utils.onnx.lib_test_models import NonSimplifiedModel, export_as_onnx
from _test_utils.onnx.quantization.utils import assert_nodes_are_quantized

from modelopt.onnx.quantization.quantize import quantize
41 changes: 41 additions & 0 deletions tests/unit/onnx/autocast/test_autocast.py
@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path

import numpy as np
import onnx
import onnx_graphsurgeon as gs
import pytest
from _test_utils.onnx.lib_test_models import build_conv_isinf_model

import modelopt.onnx.autocast.utils as utils
import modelopt.onnx.utils as onnx_utils
@@ -146,3 +149,41 @@ def test_convert_simple_model(temp_model_path, temp_output_path, keep_io_types):
assert loaded_model.graph.output[0].type.tensor_type.elem_type == expected_io_type

onnx.checker.check_model(loaded_model)


def assert_input_precision(nodes, dtype="float16"):
for node in nodes:
for inp in node.inputs:
assert inp.dtype == dtype, (
f"Node of type {node.op} has type {inp.dtype} but should have type {dtype}"
)
return True


@pytest.mark.parametrize("opset_version", [13, 21])
def test_conv_isinf_conversion(tmp_path, opset_version):
onnx_model = build_conv_isinf_model(opset_version)
onnx_path = os.path.join(tmp_path, f"conv_isinf_model_opset{opset_version}.onnx")
onnx.save(onnx_model, onnx_path)

# Convert the model
converted_model = convert_to_mixed_precision(onnx_path=onnx_path, keep_io_types=True)

# Output model should be produced in the same tmp_path
output_onnx_path = onnx_path.replace(".onnx", ".fp16.onnx")
onnx.save(converted_model, output_onnx_path)

    # Load the output model and check node input precisions
graph = gs.import_onnx(converted_model)

# Check that Conv is converted
conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
assert assert_input_precision(conv_nodes)

# Check that IsInf is running in the lowest supported precision:
# - FP32 if opset < 20, or
# - FP16 if opset >= 20
isinf_nodes = [n for n in graph.nodes if n.op == "IsInf"]
opset_version = onnx_utils.get_opset_version(converted_model)
supported_dtype = "float32" if opset_version < 20 else "float16"
assert assert_input_precision(isinf_nodes, dtype=supported_dtype)
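The opset-20 cutoff asserted above can be cross-checked against the ONNX operator schemas; a small sketch (illustrative only, not part of the test suite):

import onnx

# IsInf-10 only allows float/double inputs; IsInf-20 adds float16 and bfloat16,
# which is why the test expects float32 inputs at opset 13 and float16 at opset 21.
for opset in (13, 21):
    schema = onnx.defs.get_schema("IsInf", opset)
    allowed = {t for c in schema.type_constraints for t in c.allowed_type_strs}
    print(opset, "tensor(float16)" in allowed)  # 13 -> False, 21 -> True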
2 changes: 1 addition & 1 deletion tests/unit/onnx/test_convtranspose_qdq.py
@@ -17,7 +17,7 @@
import onnx
import pytest
import torch
from _test_utils.onnx.quantization.lib_test_models import UNet, export_as_onnx
from _test_utils.onnx.lib_test_models import UNet, export_as_onnx

from modelopt.onnx.quantization import quantize

2 changes: 1 addition & 1 deletion tests/unit/onnx/test_partitioning.py
@@ -17,7 +17,7 @@

import onnx
import onnx_graphsurgeon as gs
from _test_utils.onnx.quantization.lib_test_models import export_as_onnx
from _test_utils.onnx.lib_test_models import export_as_onnx
from _test_utils.torch.vision_models import get_tiny_resnet_and_input

from modelopt.onnx.quantization.graph_utils import (