|
28 | 28 | make_tensor_value_info, |
29 | 29 | ) |
30 | 30 |
|
| 31 | +from modelopt.onnx.trt_utils import load_onnx_model |
31 | 32 | from modelopt.onnx.utils import ( |
32 | 33 | get_input_names_from_bytes, |
33 | 34 | get_output_names_from_bytes, |
@@ -253,3 +254,72 @@ def test_remove_node_extra_training_outputs(): |
253 | 254 | value_info_names = [vi.name for vi in result_model.graph.value_info] |
254 | 255 | assert "saved_mean" not in value_info_names |
255 | 256 | assert "saved_inv_std" not in value_info_names |
| 257 | + |
| 258 | + |
def _make_matmul_relu_model(ir_version=12):
    """Build a small MatMul -> Relu ONNX model for IR-version tests.

    Args:
        ir_version: IR version stamped onto the returned model proto
            (opset stays fixed at 13).

    Returns:
        A shape-inferred, checker-validated ``onnx.ModelProto`` with one
        (1, 1024, 1024) float input, a (1024, 16) float weight initializer,
        and one (1, 1024, 16) float output.
    """
    input_names = ["input_0"]
    output_names = ["output_0"]
    input_shapes = [(1, 1024, 1024)]
    output_shapes = [(1, 1024, 16)]

    inputs = [
        make_tensor_value_info(name, onnx.TensorProto.FLOAT, shape)
        for name, shape in zip(input_names, input_shapes)
    ]
    outputs = [
        make_tensor_value_info(name, onnx.TensorProto.FLOAT, shape)
        for name, shape in zip(output_names, output_shapes)
    ]

    # MatMul((1,1024,1024) x (1024,16)) -> Relu -> (1,1024,16)
    nodes = [
        make_node(
            op_type="MatMul",
            inputs=["input_0", "weights_1"],
            outputs=["matmul1_matmul/MatMul:0"],
            name="matmul1_matmul/MatMul",
        ),
        make_node(
            op_type="Relu",
            inputs=["matmul1_matmul/MatMul:0"],
            outputs=["output_0"],
            name="relu1_relu/Relu",
        ),
    ]

    # Seed the generator so the fixture model is byte-reproducible across
    # runs, and cast explicitly to float32 to match TensorProto.FLOAT.
    rng = np.random.default_rng(42)
    initializers = [
        make_tensor(
            name="weights_1",
            data_type=onnx.TensorProto.FLOAT,
            dims=(1024, 16),
            vals=rng.uniform(low=0.5, high=1.0, size=1024 * 16).astype(np.float32),
        ),
    ]

    graph = make_graph(
        nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers
    )

    model = make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = ir_version

    # Run shape inference before the checker so all value infos are complete.
    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred
| 316 | + |
| 317 | + |
def test_ir_version_support(tmp_path):
    """load_onnx_model must downgrade a newer-IR model to the supported IR version 10."""
    # Save a model stamped with an intentionally too-new IR version.
    source_model = _make_matmul_relu_model(ir_version=12)
    model_path = str(tmp_path / "test_matmul_relu.onnx")
    onnx.save(source_model, model_path)

    # Reload through the utility under test and inspect the resulting IR version.
    model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[])
    assert model_reload.ir_version == 10, (
        f"The maximum supported IR version is 10, but version {model_reload.ir_version} was detected."
    )
0 commit comments