From 6ee9eeaf1c42e06cee1abd9e32e93481088a4b97 Mon Sep 17 00:00:00 2001 From: zpye Date: Mon, 29 Sep 2025 07:08:05 -0500 Subject: [PATCH 1/4] [VitisAI] Fix OrtShapeInferContext for optional inputs ### Description When there is an optional input (empty input type) in the OrtShapeInferContext construction, use undefined data type and empty shape as a placeholder. ### Motivation and Context VitisAI EP may add nodes with optional inputs during graph optimization to meet the requirements of AMD AI compilers. This fix may help other execution providers to improve the graph optimization process. --- onnxruntime/core/session/custom_ops.cc | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 00f5017a55db0..d89abf93a76bb 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -71,16 +71,21 @@ struct OrtShapeInferContext { auto num_inputs = ctx_.getNumInputs(); for (size_t ith_input = 0; ith_input < num_inputs; ++ith_input) { const auto* input_type = ctx_.getInputType(ith_input); - const auto& value_case = input_type->value_case(); - ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kTensorType, - "shape inference not yet supported for non-tensor types"); - const auto& shape_proto = input_type->tensor_type().shape(); - const auto& type_proto = input_type->tensor_type(); - auto elem_type = ::onnxruntime::utils::CApiElementTypeFromProtoType(type_proto.elem_type()); - auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto); - auto symbolic_dims = GetSymbolicDims(shape_proto); - input_type_shapes_.emplace_back( - OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, &tensor_shape, &symbolic_dims)); + if (input_type != nullptr) { + const auto& value_case = input_type->value_case(); + ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kTensorType, + "shape inference not yet supported for non-tensor types"); + const auto& shape_proto = input_type->tensor_type().shape(); + const auto& type_proto = input_type->tensor_type(); + auto elem_type = ::onnxruntime::utils::CApiElementTypeFromProtoType(type_proto.elem_type()); + auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto); + auto symbolic_dims = GetSymbolicDims(shape_proto); + input_type_shapes_.emplace_back( + OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, &tensor_shape, &symbolic_dims)); + } else { + input_type_shapes_.emplace_back( + OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, onnxruntime::TensorShape(), nullptr).release()); + } } } From 69cca17ad85ebaf7d70620d0799e8adaef5bc993 Mon Sep 17 00:00:00 2001 From: zpye Date: Mon, 29 Sep 2025 07:29:47 -0500 Subject: [PATCH 2/4] [fix] remove release() --- onnxruntime/core/session/custom_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index d89abf93a76bb..93e53c5b32714 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -84,7 +84,7 @@ struct OrtShapeInferContext { OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, &tensor_shape, &symbolic_dims)); } else { input_type_shapes_.emplace_back( - OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, onnxruntime::TensorShape(), nullptr).release()); + OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, onnxruntime::TensorShape(), nullptr)); } } } From f4c51cff6c42858769b44b21eba46c87943e60e3 Mon Sep 17 00:00:00 2001 From: zpye Date: Tue, 30 Sep 2025 04:46:58 -0500 Subject: [PATCH 3/4] [test] update test, fix code style --- onnxruntime/core/session/custom_ops.cc | 3 ++- onnxruntime/test/framework/shape_inference_test.cc | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 93e53c5b32714..9bc6c8d0a96a1 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -84,7 +84,8 @@ struct OrtShapeInferContext { OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, &tensor_shape, &symbolic_dims)); } else { input_type_shapes_.emplace_back( - OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, onnxruntime::TensorShape(), nullptr)); + OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper( + ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, nullptr, nullptr)); } } } diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index f5258760eb20d..fca0cc08eb316 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -129,6 +129,9 @@ const ORTCHAR_T* const OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = ORT_TSTR("testdata // that inference proceeds for all of the outputs when absent optional inputs are present TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) { MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider}; + custom_op.InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext*) -> OrtStatusPtr { + return nullptr; + }; const auto& env = GetEnvironment(); From 34e3b8b2600feefc0150c14539b18c12437bd757 Mon Sep 17 00:00:00 2001 From: zpye Date: Tue, 30 Sep 2025 08:25:04 -0500 Subject: [PATCH 4/4] [fix] add unused param names --- onnxruntime/test/framework/shape_inference_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index fca0cc08eb316..2d5c3a43ee8ed 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -129,7 +129,7 @@ const ORTCHAR_T* const OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = ORT_TSTR("testdata // that inference proceeds for all of the outputs when absent optional inputs are present TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) { MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider}; - custom_op.InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext*) -> OrtStatusPtr { + custom_op.InferOutputShapeFn = [](const OrtCustomOp* /*op*/, OrtShapeInferContext* /*ctx*/) -> OrtStatusPtr { return nullptr; };