Skip to content

Commit e400d7f

Browse files
committed
nit: added ir_version as arg in model building function
Signed-off-by: gcunhase <[email protected]>
1 parent 3f02fec commit e400d7f

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def build_conv_concat_model():
374374
return model_inferred
375375

376376

377-
def build_matmul_relu_model_ir_12():
377+
def build_matmul_relu_model(ir_version=12):
378378
# Define your model inputs and outputs
379379
input_names = ["input_0"]
380380
output_names = ["output_0"]
@@ -417,12 +417,14 @@ def build_matmul_relu_model_ir_12():
417417
]
418418

419419
# Create the ONNX graph with the nodes and initializers
420-
graph = helper.make_graph(nodes, "matmul_relu", inputs, outputs, initializer=initializers)
420+
graph = helper.make_graph(
421+
nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers
422+
)
421423

422424
# Create the ONNX model
423425
model = helper.make_model(graph)
424426
model.opset_import[0].version = 13
425-
model.ir_version = 12
427+
model.ir_version = ir_version
426428

427429
# Check the ONNX model
428430
model_inferred = onnx.shape_inference.infer_shapes(model)

tests/unit/onnx/test_onnx_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919
import onnx
2020
import pytest
21-
from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model_ir_12
21+
from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model
2222
from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
2323
from onnx.helper import (
2424
make_graph,
@@ -258,7 +258,7 @@ def test_remove_node_extra_training_outputs():
258258

259259

260260
def test_ir_version_support(tmp_path):
261-
model = build_matmul_relu_model_ir_12()
261+
model = build_matmul_relu_model(ir_version=12)
262262
model_path = os.path.join(tmp_path, "test_matmul_relu.onnx")
263263
onnx.save(model, model_path)
264264
model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[])

0 commit comments

Comments
 (0)