Skip to content

Commit ec3fe85

Browse files
committed
Add unittest
Signed-off-by: gcunhase <[email protected]>
1 parent 39ef962 commit ec3fe85

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,60 @@ def build_conv_concat_model():
372372
onnx.checker.check_model(model_inferred)
373373

374374
return model_inferred
375+
376+
377+
def build_matmul_relu_model_ir_12():
378+
# Define your model inputs and outputs
379+
input_names = ["input_0"]
380+
output_names = ["output_0"]
381+
input_shapes = [(1, 1024, 1024)]
382+
output_shapes = [(1, 1024, 16)]
383+
384+
inputs = [
385+
helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
386+
for input_name, input_shape in zip(input_names, input_shapes)
387+
]
388+
outputs = [
389+
helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
390+
for output_name, output_shape in zip(output_names, output_shapes)
391+
]
392+
393+
# Create the ONNX graph with the nodes
394+
nodes = [
395+
helper.make_node(
396+
op_type="MatMul",
397+
inputs=["input_0", "weights_1"],
398+
outputs=["matmul1_matmul/MatMul:0"],
399+
name="matmul1_matmul/MatMul",
400+
),
401+
helper.make_node(
402+
op_type="Relu",
403+
inputs=["matmul1_matmul/MatMul:0"],
404+
outputs=["output_0"],
405+
name="relu1_relu/Relu",
406+
),
407+
]
408+
409+
# Create the ONNX initializers
410+
initializers = [
411+
helper.make_tensor(
412+
name="weights_1",
413+
data_type=onnx.TensorProto.FLOAT,
414+
dims=(1024, 16),
415+
vals=np.random.uniform(low=0.5, high=1.0, size=1024 * 16),
416+
),
417+
]
418+
419+
# Create the ONNX graph with the nodes and initializers
420+
graph = helper.make_graph(nodes, "r1a", inputs, outputs, initializer=initializers)
421+
422+
# Create the ONNX model
423+
model = helper.make_model(graph)
424+
model.opset_import[0].version = 13
425+
model.ir_version = 12
426+
427+
# Check the ONNX model
428+
model_inferred = onnx.shape_inference.infer_shapes(model)
429+
onnx.checker.check_model(model_inferred)
430+
431+
return model_inferred

tests/unit/onnx/test_onnx_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import onnx
2020
import pytest
21+
from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model_ir_12
2122
from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
2223
from onnx.helper import (
2324
make_graph,
@@ -28,6 +29,7 @@
2829
make_tensor_value_info,
2930
)
3031

32+
from modelopt.onnx.trt_utils import load_onnx_model
3133
from modelopt.onnx.utils import (
3234
get_input_names_from_bytes,
3335
get_output_names_from_bytes,
@@ -253,3 +255,13 @@ def test_remove_node_extra_training_outputs():
253255
value_info_names = [vi.name for vi in result_model.graph.value_info]
254256
assert "saved_mean" not in value_info_names
255257
assert "saved_inv_std" not in value_info_names
258+
259+
260+
def test_ir_version_support(tmp_path="./"):
261+
model = build_matmul_relu_model_ir_12()
262+
model_path = os.path.join(tmp_path, "test_matmul_relu.onnx")
263+
onnx.save(model, model_path)
264+
model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[])
265+
assert model_reload.ir_version == 10, (
266+
f"The maximum supported IR version is 10, but version {model_reload.ir_version} was detected."
267+
)

0 commit comments

Comments
 (0)