18 | 18 | import numpy as np  | 
19 | 19 | import onnx  | 
20 | 20 | import pytest  | 
21 |  | -from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model  | 
22 | 21 | from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input  | 
23 | 22 | from onnx.helper import (  | 
24 | 23 |     make_graph,  | 
@@ -257,8 +256,67 @@ def test_remove_node_extra_training_outputs():  | 
257 | 256 |     assert "saved_inv_std" not in value_info_names  | 
258 | 257 |  | 
259 | 258 |  | 
 | 259 | +def _make_matmul_relu_model(ir_version=12):  | 
 | 260 | +    # Define the model inputs and outputs  | 
 | 261 | +    input_names = ["input_0"]  | 
 | 262 | +    output_names = ["output_0"]  | 
 | 263 | +    input_shapes = [(1, 1024, 1024)]  | 
 | 264 | +    output_shapes = [(1, 1024, 16)]  | 
 | 265 | + | 
 | 266 | +    inputs = [  | 
 | 267 | +        make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)  | 
 | 268 | +        for input_name, input_shape in zip(input_names, input_shapes)  | 
 | 269 | +    ]  | 
 | 270 | +    outputs = [  | 
 | 271 | +        make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)  | 
 | 272 | +        for output_name, output_shape in zip(output_names, output_shapes)  | 
 | 273 | +    ]  | 
 | 274 | + | 
 | 275 | +    # Create the ONNX nodes  | 
 | 276 | +    nodes = [  | 
 | 277 | +        make_node(  | 
 | 278 | +            op_type="MatMul",  | 
 | 279 | +            inputs=["input_0", "weights_1"],  | 
 | 280 | +            outputs=["matmul1_matmul/MatMul:0"],  | 
 | 281 | +            name="matmul1_matmul/MatMul",  | 
 | 282 | +        ),  | 
 | 283 | +        make_node(  | 
 | 284 | +            op_type="Relu",  | 
 | 285 | +            inputs=["matmul1_matmul/MatMul:0"],  | 
 | 286 | +            outputs=["output_0"],  | 
 | 287 | +            name="relu1_relu/Relu",  | 
 | 288 | +        ),  | 
 | 289 | +    ]  | 
 | 290 | + | 
 | 291 | +    # Create the ONNX initializers  | 
 | 292 | +    initializers = [  | 
 | 293 | +        make_tensor(  | 
 | 294 | +            name="weights_1",  | 
 | 295 | +            data_type=onnx.TensorProto.FLOAT,  | 
 | 296 | +            dims=(1024, 16),  | 
 | 297 | +            vals=np.random.uniform(low=0.5, high=1.0, size=1024 * 16),  | 
 | 298 | +        ),  | 
 | 299 | +    ]  | 
 | 300 | + | 
 | 301 | +    # Create the ONNX graph with the nodes and initializers  | 
 | 302 | +    graph = make_graph(  | 
 | 303 | +        nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers  | 
 | 304 | +    )  | 
 | 305 | + | 
 | 306 | +    # Create the ONNX model  | 
 | 307 | +    model = make_model(graph)  | 
 | 308 | +    model.opset_import[0].version = 13  | 
 | 309 | +    model.ir_version = ir_version  | 
 | 310 | + | 
 | 311 | +    # Check the ONNX model  | 
 | 312 | +    model_inferred = onnx.shape_inference.infer_shapes(model)  | 
 | 313 | +    onnx.checker.check_model(model_inferred)  | 
 | 314 | + | 
 | 315 | +    return model_inferred  | 
 | 316 | + | 
 | 317 | + | 
260 | 318 | def test_ir_version_support(tmp_path):  | 
261 |  | -    model = build_matmul_relu_model(ir_version=12)  | 
 | 319 | +    model = _make_matmul_relu_model(ir_version=12)  | 
262 | 320 |     model_path = os.path.join(tmp_path, "test_matmul_relu.onnx")  | 
263 | 321 |     onnx.save(model, model_path)  | 
264 | 322 |     model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[])  | 
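
A minimal sketch (not part of the diff above) of how the new `_make_matmul_relu_model` helper could be exercised directly, assuming `onnxruntime` is installed in the test environment and using an IR version the installed runtime is assumed to accept:

import numpy as np
import onnxruntime as ort

# Build the model with a conservative IR version (assumption: IR 9 is accepted by the runtime).
model = _make_matmul_relu_model(ir_version=9)

# Run the serialized model on CPU.
sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
x = np.random.rand(1, 1024, 1024).astype(np.float32)
(out,) = sess.run(None, {"input_0": x})

# MatMul of (1, 1024, 1024) with (1024, 16) weights, followed by Relu,
# gives a non-negative (1, 1024, 16) output.
assert out.shape == (1, 1024, 16)
assert (out >= 0).all()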