Skip to content

Commit 00e1103

Browse files
committed
Ensure that the ONNX IR version is the max supported version (10)
Signed-off-by: gcunhase <[email protected]>
1 parent 3a76d28 commit 00e1103

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

modelopt/onnx/trt_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
except ImportError:
3737
TRT_PYTHON_AVAILABLE = False
3838

39+
MAX_IR_VERSION = 10
40+
3941

4042
def get_custom_layers(
4143
onnx_path: str | onnx.ModelProto,
@@ -318,11 +320,23 @@ def load_onnx_model(
318320
# Infer types and shapes in the graph for ORT compatibility
319321
onnx_model = infer_types_shapes_tensorrt(onnx_model, trt_plugins or [], all_tensor_info)
320322

323+
# Enforce IR version = 10
324+
ir_version_onnx_path = None
325+
if onnx_model.ir_version > MAX_IR_VERSION:
326+
onnx_model.ir_version = MAX_IR_VERSION
327+
ir_version_onnx_path = (
328+
static_shaped_onnx_path.replace(".onnx", f"_ir{MAX_IR_VERSION}.onnx")
329+
if static_shaped_onnx_path
330+
else onnx_path.replace(".onnx", f"_ir{MAX_IR_VERSION}.onnx")
331+
)
332+
save_onnx(onnx_model, ir_version_onnx_path, use_external_data_format)
333+
intermediate_generated_files.append(ir_version_onnx_path) # type: ignore[union-attr]
334+
321335
return (
322336
onnx_model,
323337
has_custom_op,
324338
custom_ops,
325-
static_shaped_onnx_path or onnx_path,
339+
ir_version_onnx_path or static_shaped_onnx_path or onnx_path,
326340
use_external_data_format,
327341
)
328342

0 commit comments

Comments
 (0)