Merged
Changes from 8 commits
4 changes: 0 additions & 4 deletions modelopt/onnx/autocast/graphsanitizer.py
@@ -92,7 +92,6 @@ def convert_fp64_to_fp32(self) -> None:

if modified:
logger.info("Converted FP64 initializers, I/O types, and nodes to FP32")
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)

def ensure_custom_ops_precision(self) -> None:
"""Ensure that custom ops run in the requested precision."""
@@ -144,9 +143,6 @@ def remove_disconnected_outputs(self) -> None:
def convert_opset(self) -> None:
"""Convert the model to the given opset version.

Args:
min_opset: minimum opset version to use

The method checks all opset imports and converts the model if any are below the minimum version.
"""
# Check all opset imports
7 changes: 6 additions & 1 deletion modelopt/onnx/autocast/precisionconverter.py
@@ -138,6 +138,11 @@ def __init__(
self.min_opset = min_opset
self.max_ir_version = max_ir_version
self.trt_plugins = trt_plugins
self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + (
utils.get_op_types_not_supported_in_low_precision(
self.model, self.low_precision_type.str_full, self.min_opset
)
)

def convert(
self,
@@ -446,7 +451,7 @@ def _filter_unsupported_op_types(
# precision so we need to set Resize and Upsample to high precision
for node in self.model.graph.node:
if (
node.op_type in OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION
node.op_type in self.op_types_not_supported_in_low_precision
and node.name in low_precision_nodes
):
low_precision_nodes.remove(node.name)
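As a hedged illustration of the constructor change above, the sketch below combines the static exclusion list with the new opset-aware helper and keeps matching nodes in high precision, mirroring _filter_unsupported_op_types. The static list contents, model path, and opset value are placeholder assumptions, not values from this PR.

# Sketch only: merge static and opset-dependent exclusions, then drop matching
# nodes from the low-precision set. Placeholder values are marked as such.
import onnx

from modelopt.onnx.autocast import utils

STATIC_EXCLUSIONS = ["Celu", "NonMaxSuppression"]  # placeholder for OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION

model = onnx.load("model.onnx")  # placeholder model path
excluded_op_types = STATIC_EXCLUSIONS + utils.get_op_types_not_supported_in_low_precision(
    model, "float16", 13
)

low_precision_nodes = [n.name for n in model.graph.node]
for node in model.graph.node:
    if node.op_type in excluded_op_types and node.name in low_precision_nodes:
        low_precision_nodes.remove(node.name)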
62 changes: 62 additions & 0 deletions modelopt/onnx/autocast/utils.py
@@ -21,6 +21,9 @@
support the core functionality of model precision conversion.
"""

import logging
from collections import defaultdict

import onnx


@@ -115,3 +118,62 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int:
if attr.name == "to":
return attr.i
raise ValueError("Cast node does not have 'to' attribute")


def get_op_types_not_supported_in_low_precision(
model: onnx.ModelProto,
low_precision_type: str,
min_opset: int,
) -> list[str]:
"""Get a list of ops not supported in low precision for the current opset version.

Args:
model: ONNX model.
low_precision_type: Target precision to reduce to ('float16' or 'bfloat16').
min_opset: Minimum opset version.

Returns:
ops_without_support: List of ops not supported in low precision for the current opset version.
"""
# Obtain the current model's opset version
ai_onnx_domain = [
opset
for opset in model.opset_import
if not opset.domain or opset.domain in ["ai.onnx", "ai.onnx.contrib", "trt.plugins"]
]
opset_version = max(ai_onnx_domain[0].version, min_opset)

# Get all ops precision support information
precision = "tensor(float16)" if low_precision_type == "float16" else "tensor(bfloat16)"
model_ops = {n.op_type for n in model.graph.node}
schemas_dict = defaultdict(dict)
for schema in onnx.defs.get_all_schemas_with_history():
if schema.name not in model_ops:
continue
float16_supported = False
for constr in schema.type_constraints:
if precision in constr.allowed_type_strs:
float16_supported = True
break
schemas_dict[schema.name].update({schema.since_version: float16_supported})

# Check that all ops are supported in low precision for the current opset version.
# Otherwise, exclude from conversion.
ops_without_support = {}
for op, schema in schemas_dict.items():
supported_opsets = [k for k, v in schema.items() if v]
if supported_opsets:
            op_min_opset = min(supported_opsets)
            if op_min_opset > opset_version:
                ops_without_support[op] = op_min_opset
else:
ops_without_support[op] = None

if ops_without_support:
        logging.warning(
            f"{len(ops_without_support)} op types are not supported in '{low_precision_type}' at opset "
            f"{opset_version} and are excluded from conversion. Upgrade the model's opset to the version "
            f"listed per op to run them in low precision (None means no opset supports it): {ops_without_support}."
        )

return list(ops_without_support.keys())
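For context, a minimal usage sketch of the new helper follows; the model path and minimum opset are illustrative assumptions, and the expected output reflects the opset-13 IsInf test model added later in this PR (IsInf only gains float16 support in opset 20).

# Sketch: list the op types in a model that cannot run in float16 at its opset.
import onnx

from modelopt.onnx.autocast import utils

model = onnx.load("conv_isinf_model.onnx")  # placeholder path, e.g. the opset-13 test model below
unsupported = utils.get_op_types_not_supported_in_low_precision(model, "float16", min_opset=13)
print(unsupported)  # expected to include "IsInf", since float16 support for IsInf starts at opset 20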
99 changes: 99 additions & 0 deletions tests/_test_utils/onnx/quantization/lib_test_models.py
@@ -822,3 +822,102 @@ def build_conv_act_pool_model(include_reshape_node=False):
onnx.checker.check_model(model_inferred)

return model_inferred


def build_conv_isinf_model():
# Define your model inputs and outputs
input_names = ["input_0"]
output_names = ["output_0"]
input_shapes = [(6, 32, 900, 256)]
output_shapes = [(6, 32, 900, 256)]

inputs = [
helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
for input_name, input_shape in zip(input_names, input_shapes)
]
outputs = [
helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
for output_name, output_shape in zip(output_names, output_shapes)
]

    # Create the ONNX nodes
nodes = [
helper.make_node(
op_type="Conv",
inputs=["input_0", "weights_1"],
outputs=["conv1_conv/Conv2D:0"],
name="conv1_conv/Conv2D",
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
helper.make_node(
op_type="Cast",
inputs=["conv1_conv/Conv2D:0"],
outputs=["cast1_cast/Cast:0"],
name="cast1_cast/Cast",
to=onnx.TensorProto.DOUBLE,
),
helper.make_node(
op_type="IsInf",
inputs=["cast1_cast/Cast:0"],
outputs=["isinf1_isinf/IsInf:0"],
name="isinf1_isinf/IsInf",
),
helper.make_node(
op_type="Greater",
inputs=["conv1_conv/Conv2D:0", "greater_const1"],
outputs=["greater1_greater/Greater:0"],
name="greater1_greater/Greater",
),
helper.make_node(
op_type="And",
inputs=["isinf1_isinf/IsInf:0", "greater1_greater/Greater:0"],
outputs=["and1_and/And:0"],
name="and1_and/And",
),
helper.make_node(
op_type="Where",
inputs=["and1_and/And:0", "conv1_conv/Conv2D:0", "where_const1"],
outputs=["output_0"],
name="where1_where/Where",
),
]

# Create the ONNX initializers
initializers = [
helper.make_tensor(
name="weights_1",
data_type=onnx.TensorProto.FLOAT,
dims=(32, 32, 3, 3),
vals=np.random.uniform(low=0.5, high=1.0, size=32 * 32 * 3 * 3),
),
helper.make_tensor(
name="greater_const1",
data_type=onnx.TensorProto.FLOAT,
dims=(1,),
vals=[0],
),
helper.make_tensor(
name="where_const1",
data_type=onnx.TensorProto.FLOAT,
dims=(1,),
vals=[10000],
),
]

# Create the ONNX graph with the nodes and initializers
graph = helper.make_graph(nodes, "conv_isinf", inputs, outputs, initializer=initializers)

# Create the ONNX model
model = helper.make_model(graph)
model.opset_import[0].version = 13
model.ir_version = 10

# Check the ONNX model
model_inferred = onnx.shape_inference.infer_shapes(model)
onnx.checker.check_model(model_inferred)

return model_inferred
30 changes: 30 additions & 0 deletions tests/unit/onnx/test_qdq_rules_int8.py
@@ -22,6 +22,7 @@
from _test_utils.onnx.quantization.lib_test_models import (
build_conv_act_pool_model,
build_conv_batchnorm_sig_mul_model,
build_conv_isinf_model,
build_convtranspose_conv_residual_model,
build_r1a_model,
build_resnet_block,
@@ -208,3 +209,32 @@ def test_conv_act_pool_int8(tmp_path, include_reshape_node):
# Check that MaxPool is not quantized
pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
assert assert_nodes_are_not_quantized(pool_nodes)


def test_conv_isinf_int8(tmp_path):
onnx_model = build_conv_isinf_model()
onnx_path = os.path.join(tmp_path, "conv_isinf_model.onnx")
save_onnx(onnx_model, onnx_path)

quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

# Output model should be produced in the same tmp_path
output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")

# Check that quantized explicit model is generated
assert os.path.isfile(output_onnx_path)

# Load the output model and check QDQ node placements
graph = gs.import_onnx(onnx.load(output_onnx_path))

# Check that Conv is quantized
conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
assert assert_nodes_are_quantized(conv_nodes)

# Check that IsInf is running in FP32
isinf_nodes = [n for n in graph.nodes if n.op == "IsInf"]
for node in isinf_nodes:
for inp in node.inputs:
            assert inp.dtype == "float32", (
                f"Input of 'IsInf' node has dtype {inp.dtype} but should be float32"
            )