Skip to content

Commit e396c55

Browse files
committed
nit: added ir_version as arg in model building function
Signed-off-by: gcunhase <[email protected]>
1 parent e1b53d9 commit e396c55

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def build_convtranspose_conv_residual_model():
556556
return model_inferred
557557

558558

559-
def build_matmul_relu_model_ir_12():
559+
def build_matmul_relu_model(ir_version=12):
560560
# Define your model inputs and outputs
561561
input_names = ["input_0"]
562562
output_names = ["output_0"]
@@ -599,12 +599,14 @@ def build_matmul_relu_model_ir_12():
599599
]
600600

601601
# Create the ONNX graph with the nodes and initializers
602-
graph = helper.make_graph(nodes, "matmul_relu", inputs, outputs, initializer=initializers)
602+
graph = helper.make_graph(
603+
nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers
604+
)
603605

604606
# Create the ONNX model
605607
model = helper.make_model(graph)
606608
model.opset_import[0].version = 13
607-
model.ir_version = 12
609+
model.ir_version = ir_version
608610

609611
# Check the ONNX model
610612
model_inferred = onnx.shape_inference.infer_shapes(model)

tests/unit/onnx/test_onnx_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919
import onnx
2020
import pytest
21-
from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model_ir_12
21+
from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model
2222
from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
2323
from onnx.helper import (
2424
make_graph,
@@ -258,7 +258,7 @@ def test_remove_node_extra_training_outputs():
258258

259259

260260
def test_ir_version_support(tmp_path):
261-
model = build_matmul_relu_model_ir_12()
261+
model = build_matmul_relu_model(ir_version=12)
262262
model_path = os.path.join(tmp_path, "test_matmul_relu.onnx")
263263
onnx.save(model, model_path)
264264
model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[])

0 commit comments

Comments
 (0)