Skip to content

Commit 7808cbd

Browse files
committed
Move model building function to test_onnx_utils
Signed-off-by: gcunhase <[email protected]>
1 parent 6ca303d commit 7808cbd

File tree

1 file changed

+60
-2
lines changed

1 file changed

+60
-2
lines changed

tests/unit/onnx/test_onnx_utils.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import numpy as np
1919
import onnx
2020
import pytest
21-
from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model
2221
from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
2322
from onnx.helper import (
2423
make_graph,
@@ -257,8 +256,67 @@ def test_remove_node_extra_training_outputs():
257256
assert "saved_inv_std" not in value_info_names
258257

259258

259+
def _make_matmul_relu_model(ir_version=12):
260+
# Define your model inputs and outputs
261+
input_names = ["input_0"]
262+
output_names = ["output_0"]
263+
input_shapes = [(1, 1024, 1024)]
264+
output_shapes = [(1, 1024, 16)]
265+
266+
inputs = [
267+
make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
268+
for input_name, input_shape in zip(input_names, input_shapes)
269+
]
270+
outputs = [
271+
make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
272+
for output_name, output_shape in zip(output_names, output_shapes)
273+
]
274+
275+
# Create the ONNX graph with the nodes
276+
nodes = [
277+
make_node(
278+
op_type="MatMul",
279+
inputs=["input_0", "weights_1"],
280+
outputs=["matmul1_matmul/MatMul:0"],
281+
name="matmul1_matmul/MatMul",
282+
),
283+
make_node(
284+
op_type="Relu",
285+
inputs=["matmul1_matmul/MatMul:0"],
286+
outputs=["output_0"],
287+
name="relu1_relu/Relu",
288+
),
289+
]
290+
291+
# Create the ONNX initializers
292+
initializers = [
293+
make_tensor(
294+
name="weights_1",
295+
data_type=onnx.TensorProto.FLOAT,
296+
dims=(1024, 16),
297+
vals=np.random.uniform(low=0.5, high=1.0, size=1024 * 16),
298+
),
299+
]
300+
301+
# Create the ONNX graph with the nodes and initializers
302+
graph = make_graph(
303+
nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers
304+
)
305+
306+
# Create the ONNX model
307+
model = make_model(graph)
308+
model.opset_import[0].version = 13
309+
model.ir_version = ir_version
310+
311+
# Check the ONNX model
312+
model_inferred = onnx.shape_inference.infer_shapes(model)
313+
onnx.checker.check_model(model_inferred)
314+
315+
return model_inferred
316+
317+
260318
def test_ir_version_support(tmp_path):
261-
model = build_matmul_relu_model(ir_version=12)
319+
model = _make_matmul_relu_model(ir_version=12)
262320
model_path = os.path.join(tmp_path, "test_matmul_relu.onnx")
263321
onnx.save(model, model_path)
264322
model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[])

0 commit comments

Comments
 (0)