|
18 | 18 | import numpy as np |
19 | 19 | import onnx |
20 | 20 | import pytest |
21 | | -from _test_utils.onnx_quantization.lib_test_models import build_matmul_relu_model |
22 | 21 | from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input |
23 | 22 | from onnx.helper import ( |
24 | 23 | make_graph, |
@@ -257,8 +256,67 @@ def test_remove_node_extra_training_outputs(): |
257 | 256 | assert "saved_inv_std" not in value_info_names |
258 | 257 |
|
259 | 258 |
|
| 259 | +def _make_matmul_relu_model(ir_version=12): |
| 260 | + # Define your model inputs and outputs |
| 261 | + input_names = ["input_0"] |
| 262 | + output_names = ["output_0"] |
| 263 | + input_shapes = [(1, 1024, 1024)] |
| 264 | + output_shapes = [(1, 1024, 16)] |
| 265 | + |
| 266 | + inputs = [ |
| 267 | + make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape) |
| 268 | + for input_name, input_shape in zip(input_names, input_shapes) |
| 269 | + ] |
| 270 | + outputs = [ |
| 271 | + make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape) |
| 272 | + for output_name, output_shape in zip(output_names, output_shapes) |
| 273 | + ] |
| 274 | + |
| 275 | + # Create the ONNX graph with the nodes |
| 276 | + nodes = [ |
| 277 | + make_node( |
| 278 | + op_type="MatMul", |
| 279 | + inputs=["input_0", "weights_1"], |
| 280 | + outputs=["matmul1_matmul/MatMul:0"], |
| 281 | + name="matmul1_matmul/MatMul", |
| 282 | + ), |
| 283 | + make_node( |
| 284 | + op_type="Relu", |
| 285 | + inputs=["matmul1_matmul/MatMul:0"], |
| 286 | + outputs=["output_0"], |
| 287 | + name="relu1_relu/Relu", |
| 288 | + ), |
| 289 | + ] |
| 290 | + |
| 291 | + # Create the ONNX initializers |
| 292 | + initializers = [ |
| 293 | + make_tensor( |
| 294 | + name="weights_1", |
| 295 | + data_type=onnx.TensorProto.FLOAT, |
| 296 | + dims=(1024, 16), |
| 297 | + vals=np.random.uniform(low=0.5, high=1.0, size=1024 * 16), |
| 298 | + ), |
| 299 | + ] |
| 300 | + |
| 301 | + # Create the ONNX graph with the nodes and initializers |
| 302 | + graph = make_graph( |
| 303 | + nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers |
| 304 | + ) |
| 305 | + |
| 306 | + # Create the ONNX model |
| 307 | + model = make_model(graph) |
| 308 | + model.opset_import[0].version = 13 |
| 309 | + model.ir_version = ir_version |
| 310 | + |
| 311 | + # Check the ONNX model |
| 312 | + model_inferred = onnx.shape_inference.infer_shapes(model) |
| 313 | + onnx.checker.check_model(model_inferred) |
| 314 | + |
| 315 | + return model_inferred |
| 316 | + |
| 317 | + |
260 | 318 | def test_ir_version_support(tmp_path): |
261 | | - model = build_matmul_relu_model(ir_version=12) |
| 319 | + model = _make_matmul_relu_model(ir_version=12) |
262 | 320 | model_path = os.path.join(tmp_path, "test_matmul_relu.onnx") |
263 | 321 | onnx.save(model, model_path) |
264 | 322 | model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[]) |
|
0 commit comments