diff --git a/modelopt/onnx/trt_utils.py b/modelopt/onnx/trt_utils.py
index e5a5d9a4a..fe01d672f 100644
--- a/modelopt/onnx/trt_utils.py
+++ b/modelopt/onnx/trt_utils.py
@@ -36,6 +36,8 @@
 except ImportError:
     TRT_PYTHON_AVAILABLE = False
 
+MAX_IR_VERSION = 10
+
 
 def get_custom_layers(
     onnx_path: str | onnx.ModelProto,
@@ -296,7 +298,8 @@ def load_onnx_model(
 
             static_shaped_onnx_path = onnx_path.replace(".onnx", "_static.onnx")
             save_onnx(onnx_model, static_shaped_onnx_path, use_external_data_format)
-            intermediate_generated_files.append(static_shaped_onnx_path)  # type: ignore[union-attr]
+            if intermediate_generated_files is not None:
+                intermediate_generated_files.append(static_shaped_onnx_path)
 
     if TRT_PYTHON_AVAILABLE and platform.system() != "Windows":
         # Check if there's a custom TensorRT op in the ONNX model. If so, make it ORT compatible by adding
@@ -318,11 +321,27 @@
             # Infer types and shapes in the graph for ORT compatibility
             onnx_model = infer_types_shapes_tensorrt(onnx_model, trt_plugins or [], all_tensor_info)
 
+    # Enforce the maximum supported IR version
+    ir_version_onnx_path = None
+    if onnx_model.ir_version > MAX_IR_VERSION:
+        onnx_model.ir_version = MAX_IR_VERSION
+        ir_version_onnx_path = (
+            static_shaped_onnx_path.replace(".onnx", f"_ir{MAX_IR_VERSION}.onnx")
+            if static_shaped_onnx_path
+            else onnx_path.replace(".onnx", f"_ir{MAX_IR_VERSION}.onnx")
+        )
+        save_onnx(onnx_model, ir_version_onnx_path, use_external_data_format)
+        if intermediate_generated_files is not None:
+            intermediate_generated_files.append(ir_version_onnx_path)
+
+    # Check that the model is valid
+    onnx.checker.check_model(onnx_model)
+
     return (
         onnx_model,
         has_custom_op,
         custom_ops,
-        static_shaped_onnx_path or onnx_path,
+        ir_version_onnx_path or static_shaped_onnx_path or onnx_path,
         use_external_data_format,
     )
 
diff --git a/tests/unit/onnx/test_onnx_utils.py b/tests/unit/onnx/test_onnx_utils.py
index ede97302d..a58da7002 100644
--- a/tests/unit/onnx/test_onnx_utils.py
+++ b/tests/unit/onnx/test_onnx_utils.py
@@ -28,6 +28,7 @@
     make_tensor_value_info,
 )
 
+from modelopt.onnx.trt_utils import load_onnx_model
 from modelopt.onnx.utils import (
     get_input_names_from_bytes,
     get_output_names_from_bytes,
@@ -253,3 +254,72 @@ def test_remove_node_extra_training_outputs():
     value_info_names = [vi.name for vi in result_model.graph.value_info]
     assert "saved_mean" not in value_info_names
     assert "saved_inv_std" not in value_info_names
+
+
+def _make_matmul_relu_model(ir_version=12):
+    # Define the model inputs and outputs
+    input_names = ["input_0"]
+    output_names = ["output_0"]
+    input_shapes = [(1, 1024, 1024)]
+    output_shapes = [(1, 1024, 16)]
+
+    inputs = [
+        make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
+        for input_name, input_shape in zip(input_names, input_shapes)
+    ]
+    outputs = [
+        make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
+        for output_name, output_shape in zip(output_names, output_shapes)
+    ]
+
+    # Create the ONNX nodes
+    nodes = [
+        make_node(
+            op_type="MatMul",
+            inputs=["input_0", "weights_1"],
+            outputs=["matmul1_matmul/MatMul:0"],
+            name="matmul1_matmul/MatMul",
+        ),
+        make_node(
+            op_type="Relu",
+            inputs=["matmul1_matmul/MatMul:0"],
+            outputs=["output_0"],
+            name="relu1_relu/Relu",
+        ),
+    ]
+
+    # Create the ONNX initializers
+    initializers = [
+        make_tensor(
+            name="weights_1",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(1024, 16),
+            vals=np.random.uniform(low=0.5, high=1.0, size=1024 * 16),
+        ),
+    ]
+
+    # Create the ONNX graph with the nodes and initializers
+    graph = make_graph(
+        nodes, f"matmul_relu_ir_{ir_version}", inputs, outputs, initializer=initializers
+    )
+
+    # Create the ONNX model
+    model = make_model(graph)
+    model.opset_import[0].version = 13
+    model.ir_version = ir_version
+
+    # Run shape inference and validate the model
+    model_inferred = onnx.shape_inference.infer_shapes(model)
+    onnx.checker.check_model(model_inferred)
+
+    return model_inferred
+
+
+def test_ir_version_support(tmp_path):
+    model = _make_matmul_relu_model(ir_version=12)
+    model_path = os.path.join(tmp_path, "test_matmul_relu.onnx")
+    onnx.save(model, model_path)
+    model_reload, _, _, _, _ = load_onnx_model(model_path, intermediate_generated_files=[])
+    assert model_reload.ir_version == 10, (
+        f"The maximum supported IR version is 10, but version {model_reload.ir_version} was detected."
+    )
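
For reference, a minimal caller-side sketch of the behavior this diff introduces (the input file name `model_ir12.onnx` is hypothetical; `load_onnx_model` and `MAX_IR_VERSION` are the symbols changed above). A model authored with `ir_version > 10` is clamped to `MAX_IR_VERSION`, re-saved as an `*_ir10.onnx` copy, and that copy becomes the effective path and is tracked in `intermediate_generated_files` when a list is passed:

```python
from modelopt.onnx.trt_utils import MAX_IR_VERSION, load_onnx_model

intermediate_files = []
model, has_custom_op, custom_ops, effective_path, uses_external_data = load_onnx_model(
    "model_ir12.onnx", intermediate_generated_files=intermediate_files
)

assert model.ir_version <= MAX_IR_VERSION  # clamped in place if it exceeded 10
# If the input exceeded IR version 10 (and no static-shape copy was made),
# effective_path points at the re-saved "model_ir12_ir10.onnx" copy, which is
# also recorded in intermediate_files for later cleanup.
print(effective_path, intermediate_files)
```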