Skip to content

Commit 6ca303d

Browse files
committed
nit: added ir_version as arg in model building function
Signed-off-by: gcunhase <[email protected]>
1 parent aeb425e commit 6ca303d

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def build_convtranspose_conv_residual_model():
557557
return model_inferred
558558

559559

560-
def build_matmul_relu_model_ir_12():
560+
def build_matmul_relu_model(ir_version=12):
561561
# Define your model inputs and outputs
562562
input_names = ["input_0"]
563563
output_names = ["output_0"]
@@ -600,12 +600,14 @@ def build_matmul_relu_model_ir_12():
600600
]
601601

602602
# Create the ONNX graph with the nodes and initializers
603-
graph = helper.make_graph(nodes, "matmul_relu", inputs, outputs, initializer=initializers)
603+
graph = helper.make_graph(
604+
nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers
605+
)
604606

605607
# Create the ONNX model
606608
model = helper.make_model(graph)
607609
model.opset_import[0].version = 13
608-
model.ir_version = 12
610+
model.ir_version = ir_version
609611

610612
# Check the ONNX model
611613
model_inferred = onnx.shape_inference.infer_shapes(model)

tests/unit/onnx/test_onnx_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919
import onnx
2020
import pytest
21-
from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model_ir_12
21+
from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model
2222
from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
2323
from onnx.helper import (
2424
make_graph,
@@ -258,7 +258,7 @@ def test_remove_node_extra_training_outputs():
258258

259259

260260
def test_ir_version_support(tmp_path):
261-
model = build_matmul_relu_model_ir_12()
261+
model = build_matmul_relu_model(ir_version=12)
262262
model_path = os.path.join(tmp_path, "test_matmul_relu.onnx")
263263
onnx.save(model, model_path)
264264
model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[])

0 commit comments

Comments
 (0)