Skip to content

Commit b01bc7a

Browse files
committed
Move model building function to test_onnx_utils
Signed-off-by: gcunhase <[email protected]>
1 parent cf8dc63 commit b01bc7a

File tree

2 files changed

+60
-61
lines changed

2 files changed

+60
-61
lines changed

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -372,62 +372,3 @@ def build_conv_concat_model():
372372
onnx.checker.check_model(model_inferred)
373373

374374
return model_inferred
375-
376-
377-
def build_matmul_relu_model(ir_version=12):
378-
# Define your model inputs and outputs
379-
input_names = ["input_0"]
380-
output_names = ["output_0"]
381-
input_shapes = [(1, 1024, 1024)]
382-
output_shapes = [(1, 1024, 16)]
383-
384-
inputs = [
385-
helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
386-
for input_name, input_shape in zip(input_names, input_shapes)
387-
]
388-
outputs = [
389-
helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
390-
for output_name, output_shape in zip(output_names, output_shapes)
391-
]
392-
393-
# Create the ONNX graph with the nodes
394-
nodes = [
395-
helper.make_node(
396-
op_type="MatMul",
397-
inputs=["input_0", "weights_1"],
398-
outputs=["matmul1_matmul/MatMul:0"],
399-
name="matmul1_matmul/MatMul",
400-
),
401-
helper.make_node(
402-
op_type="Relu",
403-
inputs=["matmul1_matmul/MatMul:0"],
404-
outputs=["output_0"],
405-
name="relu1_relu/Relu",
406-
),
407-
]
408-
409-
# Create the ONNX initializers
410-
initializers = [
411-
helper.make_tensor(
412-
name="weights_1",
413-
data_type=onnx.TensorProto.FLOAT,
414-
dims=(1024, 16),
415-
vals=np.random.uniform(low=0.5, high=1.0, size=1024 * 16),
416-
),
417-
]
418-
419-
# Create the ONNX graph with the nodes and initializers
420-
graph = helper.make_graph(
421-
nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers
422-
)
423-
424-
# Create the ONNX model
425-
model = helper.make_model(graph)
426-
model.opset_import[0].version = 13
427-
model.ir_version = ir_version
428-
429-
# Check the ONNX model
430-
model_inferred = onnx.shape_inference.infer_shapes(model)
431-
onnx.checker.check_model(model_inferred)
432-
433-
return model_inferred

tests/unit/onnx/test_onnx_utils.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import numpy as np
1919
import onnx
2020
import pytest
21-
from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model
2221
from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
2322
from onnx.helper import (
2423
make_graph,
@@ -257,8 +256,67 @@ def test_remove_node_extra_training_outputs():
257256
assert "saved_inv_std" not in value_info_names
258257

259258

259+
def _make_matmul_relu_model(ir_version=12):
260+
# Define your model inputs and outputs
261+
input_names = ["input_0"]
262+
output_names = ["output_0"]
263+
input_shapes = [(1, 1024, 1024)]
264+
output_shapes = [(1, 1024, 16)]
265+
266+
inputs = [
267+
make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
268+
for input_name, input_shape in zip(input_names, input_shapes)
269+
]
270+
outputs = [
271+
make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
272+
for output_name, output_shape in zip(output_names, output_shapes)
273+
]
274+
275+
# Create the ONNX graph with the nodes
276+
nodes = [
277+
make_node(
278+
op_type="MatMul",
279+
inputs=["input_0", "weights_1"],
280+
outputs=["matmul1_matmul/MatMul:0"],
281+
name="matmul1_matmul/MatMul",
282+
),
283+
make_node(
284+
op_type="Relu",
285+
inputs=["matmul1_matmul/MatMul:0"],
286+
outputs=["output_0"],
287+
name="relu1_relu/Relu",
288+
),
289+
]
290+
291+
# Create the ONNX initializers
292+
initializers = [
293+
make_tensor(
294+
name="weights_1",
295+
data_type=onnx.TensorProto.FLOAT,
296+
dims=(1024, 16),
297+
vals=np.random.uniform(low=0.5, high=1.0, size=1024 * 16),
298+
),
299+
]
300+
301+
# Create the ONNX graph with the nodes and initializers
302+
graph = make_graph(
303+
nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers
304+
)
305+
306+
# Create the ONNX model
307+
model = make_model(graph)
308+
model.opset_import[0].version = 13
309+
model.ir_version = ir_version
310+
311+
# Check the ONNX model
312+
model_inferred = onnx.shape_inference.infer_shapes(model)
313+
onnx.checker.check_model(model_inferred)
314+
315+
return model_inferred
316+
317+
260318
def test_ir_version_support(tmp_path):
261-
model = build_matmul_relu_model(ir_version=12)
319+
model = _make_matmul_relu_model(ir_version=12)
262320
model_path = os.path.join(tmp_path, "test_matmul_relu.onnx")
263321
onnx.save(model, model_path)
264322
model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[])

0 commit comments

Comments
 (0)