Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions modelopt/onnx/trt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
except ImportError:
TRT_PYTHON_AVAILABLE = False

MAX_IR_VERSION = 10


def get_custom_layers(
onnx_path: str | onnx.ModelProto,
Expand Down Expand Up @@ -296,7 +298,8 @@ def load_onnx_model(

static_shaped_onnx_path = onnx_path.replace(".onnx", "_static.onnx")
save_onnx(onnx_model, static_shaped_onnx_path, use_external_data_format)
intermediate_generated_files.append(static_shaped_onnx_path) # type: ignore[union-attr]
if intermediate_generated_files is not None:
intermediate_generated_files.append(static_shaped_onnx_path)

if TRT_PYTHON_AVAILABLE and platform.system() != "Windows":
# Check if there's a custom TensorRT op in the ONNX model. If so, make it ORT compatible by adding
Expand All @@ -318,11 +321,27 @@ def load_onnx_model(
# Infer types and shapes in the graph for ORT compatibility
onnx_model = infer_types_shapes_tensorrt(onnx_model, trt_plugins or [], all_tensor_info)

# Enforce IR version = 10
ir_version_onnx_path = None
if onnx_model.ir_version > MAX_IR_VERSION:
onnx_model.ir_version = MAX_IR_VERSION
ir_version_onnx_path = (
static_shaped_onnx_path.replace(".onnx", f"_ir{MAX_IR_VERSION}.onnx")
if static_shaped_onnx_path
else onnx_path.replace(".onnx", f"_ir{MAX_IR_VERSION}.onnx")
)
save_onnx(onnx_model, ir_version_onnx_path, use_external_data_format)
if intermediate_generated_files is not None:
intermediate_generated_files.append(ir_version_onnx_path)

# Check that the model is valid
onnx.checker.check_model(onnx_model)

return (
onnx_model,
has_custom_op,
custom_ops,
static_shaped_onnx_path or onnx_path,
ir_version_onnx_path or static_shaped_onnx_path or onnx_path,
use_external_data_format,
)

Expand Down
70 changes: 70 additions & 0 deletions tests/unit/onnx/test_onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
make_tensor_value_info,
)

from modelopt.onnx.trt_utils import load_onnx_model
from modelopt.onnx.utils import (
get_input_names_from_bytes,
get_output_names_from_bytes,
Expand Down Expand Up @@ -253,3 +254,72 @@ def test_remove_node_extra_training_outputs():
value_info_names = [vi.name for vi in result_model.graph.value_info]
assert "saved_mean" not in value_info_names
assert "saved_inv_std" not in value_info_names


def _make_matmul_relu_model(ir_version=12):
    """Build a small MatMul -> Relu ONNX model stamped with the given IR version.

    Args:
        ir_version: IR version to set on the generated model. Defaults to 12,
            i.e. above the maximum IR version (10) enforced by ``load_onnx_model``.

    Returns:
        A shape-inferred, checker-validated ``onnx.ModelProto``.
    """
    # Model I/O: one (1, 1024, 1024) float input, one (1, 1024, 16) float output.
    input_names = ["input_0"]
    output_names = ["output_0"]
    input_shapes = [(1, 1024, 1024)]
    output_shapes = [(1, 1024, 16)]

    inputs = [
        make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
        for input_name, input_shape in zip(input_names, input_shapes)
    ]
    outputs = [
        make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
        for output_name, output_shape in zip(output_names, output_shapes)
    ]

    # Graph topology: output_0 = Relu(MatMul(input_0, weights_1))
    nodes = [
        make_node(
            op_type="MatMul",
            inputs=["input_0", "weights_1"],
            outputs=["matmul1_matmul/MatMul:0"],
            name="matmul1_matmul/MatMul",
        ),
        make_node(
            op_type="Relu",
            inputs=["matmul1_matmul/MatMul:0"],
            outputs=["output_0"],
            name="relu1_relu/Relu",
        ),
    ]

    # Seed the RNG so the generated weights (and therefore the serialized
    # model) are deterministic across test runs.
    rng = np.random.default_rng(42)
    initializers = [
        make_tensor(
            name="weights_1",
            data_type=onnx.TensorProto.FLOAT,
            dims=(1024, 16),
            vals=rng.uniform(low=0.5, high=1.0, size=1024 * 16),
        ),
    ]

    graph = make_graph(
        nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers
    )

    model = make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = ir_version

    # Validate before handing the model to the caller: shape inference fills in
    # intermediate tensor shapes, and the checker rejects malformed protos early.
    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred


def test_ir_version_support(tmp_path):
    """load_onnx_model must cap models above IR version 10 down to exactly 10."""
    model_path = os.path.join(tmp_path, "test_matmul_relu.onnx")
    onnx.save(_make_matmul_relu_model(ir_version=12), model_path)

    model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[])

    assert model_reload.ir_version == 10, (
        f"The maximum supported IR version is 10, but version {model_reload.ir_version} was detected."
    )