Skip to content

Commit bffe2ff

Browse files
authored
Ensure that the ONNX IR version is the max supported version (10) (#416)
Signed-off-by: gcunhase <[email protected]>
1 parent e6e0d2c commit bffe2ff

File tree

2 files changed

+91
-2
lines changed

2 files changed

+91
-2
lines changed

modelopt/onnx/trt_utils.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
except ImportError:
3737
TRT_PYTHON_AVAILABLE = False
3838

39+
MAX_IR_VERSION = 10
40+
3941

4042
def get_custom_layers(
4143
onnx_path: str | onnx.ModelProto,
@@ -296,7 +298,8 @@ def load_onnx_model(
296298

297299
static_shaped_onnx_path = onnx_path.replace(".onnx", "_static.onnx")
298300
save_onnx(onnx_model, static_shaped_onnx_path, use_external_data_format)
299-
intermediate_generated_files.append(static_shaped_onnx_path) # type: ignore[union-attr]
301+
if intermediate_generated_files is not None:
302+
intermediate_generated_files.append(static_shaped_onnx_path)
300303

301304
if TRT_PYTHON_AVAILABLE and platform.system() != "Windows":
302305
# Check if there's a custom TensorRT op in the ONNX model. If so, make it ORT compatible by adding
@@ -318,11 +321,27 @@ def load_onnx_model(
318321
# Infer types and shapes in the graph for ORT compatibility
319322
onnx_model = infer_types_shapes_tensorrt(onnx_model, trt_plugins or [], all_tensor_info)
320323

324+
# Enforce IR version = 10
325+
ir_version_onnx_path = None
326+
if onnx_model.ir_version > MAX_IR_VERSION:
327+
onnx_model.ir_version = MAX_IR_VERSION
328+
ir_version_onnx_path = (
329+
static_shaped_onnx_path.replace(".onnx", f"_ir{MAX_IR_VERSION}.onnx")
330+
if static_shaped_onnx_path
331+
else onnx_path.replace(".onnx", f"_ir{MAX_IR_VERSION}.onnx")
332+
)
333+
save_onnx(onnx_model, ir_version_onnx_path, use_external_data_format)
334+
if intermediate_generated_files is not None:
335+
intermediate_generated_files.append(ir_version_onnx_path)
336+
337+
# Check that the model is valid
338+
onnx.checker.check_model(onnx_model)
339+
321340
return (
322341
onnx_model,
323342
has_custom_op,
324343
custom_ops,
325-
static_shaped_onnx_path or onnx_path,
344+
ir_version_onnx_path or static_shaped_onnx_path or onnx_path,
326345
use_external_data_format,
327346
)
328347

tests/unit/onnx/test_onnx_utils.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
make_tensor_value_info,
2929
)
3030

31+
from modelopt.onnx.trt_utils import load_onnx_model
3132
from modelopt.onnx.utils import (
3233
get_input_names_from_bytes,
3334
get_output_names_from_bytes,
@@ -253,3 +254,72 @@ def test_remove_node_extra_training_outputs():
253254
value_info_names = [vi.name for vi in result_model.graph.value_info]
254255
assert "saved_mean" not in value_info_names
255256
assert "saved_inv_std" not in value_info_names
257+
258+
259+
def _make_matmul_relu_model(ir_version=12):
    """Build a tiny MatMul -> Relu ONNX model pinned to the requested IR version.

    The graph takes a (1, 1024, 1024) float input, multiplies it by a random
    (1024, 16) weight initializer, applies Relu, and produces a (1, 1024, 16)
    output. Opset is fixed at 13; ``ir_version`` defaults to 12 so callers can
    exercise downgrade logic for models above the supported IR version.

    Returns the shape-inferred model after it passes the ONNX checker.
    """
    # Single graph input/output with static shapes.
    graph_inputs = [
        make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (1, 1024, 1024))
    ]
    graph_outputs = [
        make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (1, 1024, 16))
    ]

    # MatMul feeding Relu; the intermediate tensor name links the two nodes.
    graph_nodes = [
        make_node(
            op_type="MatMul",
            inputs=["input_0", "weights_1"],
            outputs=["matmul1_matmul/MatMul:0"],
            name="matmul1_matmul/MatMul",
        ),
        make_node(
            op_type="Relu",
            inputs=["matmul1_matmul/MatMul:0"],
            outputs=["output_0"],
            name="relu1_relu/Relu",
        ),
    ]

    # Random weight initializer for the MatMul (values in [0.5, 1.0)).
    weight_init = make_tensor(
        name="weights_1",
        data_type=onnx.TensorProto.FLOAT,
        dims=(1024, 16),
        vals=np.random.uniform(low=0.5, high=1.0, size=1024 * 16),
    )

    graph = make_graph(
        graph_nodes,
        f"matmul_relu_ir_{ir_version}",
        graph_inputs,
        graph_outputs,
        initializer=[weight_init],
    )

    model = make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = ir_version

    # Validate before returning: infer shapes, then run the checker.
    inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(inferred)

    return inferred
316+
317+
318+
def test_ir_version_support(tmp_path):
    """Loading a model above IR version 10 must clamp it to IR version 10.

    Saves a fixture model built with ir_version=12, reloads it through
    ``load_onnx_model``, and asserts the reloaded model's IR version was
    reduced to the maximum supported value.
    """
    source_model = _make_matmul_relu_model(ir_version=12)
    source_path = os.path.join(tmp_path, "test_matmul_relu.onnx")
    onnx.save(source_model, source_path)

    # Only the first element of the returned tuple (the model) matters here.
    reloaded, *_ = load_onnx_model(source_path, intermediate_generated_files=[])
    assert reloaded.ir_version == 10, (
        f"The maximum supported IR version is 10, but version {reloaded.ir_version} was detected."
    )

0 commit comments

Comments
 (0)