Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6590,6 +6590,23 @@ struct OrtApi {
* \since Version 1.24
*/
ORT_API_T(bool, TensorTypeAndShape_HasShape, _In_ const OrtTensorTypeAndShapeInfo* info);

/** \brief Get all config entries from ::OrtKernelInfo.
*
* Gets all configuration entries from the ::OrtKernelInfo object as key-value pairs.
* Config entries are set on the ::OrtSessionOptions and are accessible in custom operator kernels.
*
* Used in the CreateKernel callback of an OrtCustomOp to access all session configuration entries
* during kernel construction.
*
* \param[in] info An instance of ::OrtKernelInfo.
* \param[out] out A pointer to a newly created OrtKeyValuePairs instance containing all config entries.
* Note: the user should call OrtApi::ReleaseKeyValuePairs.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out);
};

/*
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2768,6 +2768,8 @@ struct KernelInfoImpl : Base<T> {

std::string GetNodeName() const;
Logger GetLogger() const;

KeyValuePairs GetConfigEntries() const;
};

} // namespace detail
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -2822,6 +2822,13 @@ inline Logger KernelInfoImpl<T>::GetLogger() const {
return Logger{out};
}

template <typename T>
inline KeyValuePairs KernelInfoImpl<T>::GetConfigEntries() const {
OrtKeyValuePairs* out = nullptr;
Ort::ThrowOnError(GetApi().KernelInfo_GetConfigEntries(this->p_, &out));
return KeyValuePairs{out};
}

inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
}
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,21 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* i
});
}

ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out) {
return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
const auto* op_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
const auto& config_options_map = op_info->GetConfigOptions().GetConfigOptionsMap();

auto kvps = std::make_unique<OrtKeyValuePairs>();
for (const auto& kv : config_options_map) {
kvps->Add(kv.first.c_str(), kv.second.c_str());
}

*out = kvps.release();
return nullptr;
});
}

ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
if (count_or_bytes == 0) {
*out = nullptr;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4231,6 +4231,7 @@ static constexpr OrtApi ort_api_1_to_24 = {
// End of Version 23 - DO NOT MODIFY ABOVE (see above text for more information)

&OrtApis::TensorTypeAndShape_HasShape,
&OrtApis::KernelInfo_GetConfigEntries,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -751,4 +751,7 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env,
_In_reads_(num_tensors) OrtValue* const* dst_tensors,
_In_opt_ OrtSyncStream* stream,
_In_ size_t num_tensors);

ORT_API_STATUS_IMPL(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out);

} // namespace OrtApis
15 changes: 14 additions & 1 deletion onnxruntime/test/framework/shape_inference_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,18 @@ TEST_F(ShapeInferenceTest, BasicTest) {

namespace {
struct MyCustomKernelWithOptionalInput {
MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) {
MyCustomKernelWithOptionalInput(const OrtKernelInfo* info) {
Ort::ConstKernelInfo k_info(info);

Ort::KeyValuePairs kvp = k_info.GetConfigEntries();

EXPECT_NE(nullptr, kvp.GetValue("session.inter_op.allow_spinning"));
EXPECT_STREQ("0", kvp.GetValue("session.inter_op.allow_spinning"));

EXPECT_NE(nullptr, kvp.GetValue("session.intra_op.allow_spinning"));
EXPECT_STREQ("0", kvp.GetValue("session.intra_op.allow_spinning"));

EXPECT_EQ(nullptr, kvp.GetValue("__not__exist__"));
}

OrtStatusPtr ComputeV2(OrtKernelContext* /* context */) const {
Expand Down Expand Up @@ -143,6 +154,8 @@ TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) {
SessionOptions sess_opts;
sess_opts.inter_op_param.thread_pool_size = 1;
sess_opts.intra_op_param.thread_pool_size = 1;
ASSERT_STATUS_OK(sess_opts.config_options.AddConfigEntry("session.inter_op.allow_spinning", "0"));
ASSERT_STATUS_OK(sess_opts.config_options.AddConfigEntry("session.intra_op.allow_spinning", "0"));

InferenceSessionWrapper session{sess_opts, env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2};
ASSERT_STATUS_OK(session.AddCustomOpDomains(AsSpan(op_domains)));
Expand Down
Loading