diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc
index 00f5017a55db0..9bc6c8d0a96a1 100644
--- a/onnxruntime/core/session/custom_ops.cc
+++ b/onnxruntime/core/session/custom_ops.cc
@@ -71,16 +71,22 @@ struct OrtShapeInferContext {
     auto num_inputs = ctx_.getNumInputs();
     for (size_t ith_input = 0; ith_input < num_inputs; ++ith_input) {
       const auto* input_type = ctx_.getInputType(ith_input);
-      const auto& value_case = input_type->value_case();
-      ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kTensorType,
-                  "shape inference not yet supported for non-tensor types");
-      const auto& shape_proto = input_type->tensor_type().shape();
-      const auto& type_proto = input_type->tensor_type();
-      auto elem_type = ::onnxruntime::utils::CApiElementTypeFromProtoType(type_proto.elem_type());
-      auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto);
-      auto symbolic_dims = GetSymbolicDims(shape_proto);
-      input_type_shapes_.emplace_back(
-          OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, &tensor_shape, &symbolic_dims));
+      if (input_type != nullptr) {
+        const auto& value_case = input_type->value_case();
+        ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kTensorType,
+                    "shape inference not yet supported for non-tensor types");
+        const auto& shape_proto = input_type->tensor_type().shape();
+        const auto& type_proto = input_type->tensor_type();
+        auto elem_type = ::onnxruntime::utils::CApiElementTypeFromProtoType(type_proto.elem_type());
+        auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto);
+        auto symbolic_dims = GetSymbolicDims(shape_proto);
+        input_type_shapes_.emplace_back(
+            OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, &tensor_shape, &symbolic_dims));
+      } else {
+        input_type_shapes_.emplace_back(
+            OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(
+                ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, nullptr, nullptr));
+      }
     }
   }
 
diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc
index f5258760eb20d..2d5c3a43ee8ed 100644
--- a/onnxruntime/test/framework/shape_inference_test.cc
+++ b/onnxruntime/test/framework/shape_inference_test.cc
@@ -129,6 +129,9 @@ const ORTCHAR_T* const OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = ORT_TSTR("testdata
 // that inference proceeds for all of the outputs when absent optional inputs are present
 TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) {
   MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider};
+  custom_op.InferOutputShapeFn = [](const OrtCustomOp* /*op*/, OrtShapeInferContext* /*ctx*/) -> OrtStatusPtr {
+    return nullptr;
+  };
   const auto& env = GetEnvironment();