From cb40dec9a7231ee13a64956da36e3909e838524f Mon Sep 17 00:00:00 2001 From: zpye Date: Sun, 16 Nov 2025 21:39:13 -0600 Subject: [PATCH] [feat] access config entries from KernelInfo --- .../core/session/onnxruntime_c_api.h | 17 +++++++++++++++++ .../core/session/onnxruntime_cxx_api.h | 2 ++ .../core/session/onnxruntime_cxx_inline.h | 7 +++++++ onnxruntime/core/session/custom_ops.cc | 15 +++++++++++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 1 + onnxruntime/core/session/ort_apis.h | 3 +++ .../test/framework/shape_inference_test.cc | 15 ++++++++++++++- 7 files changed, 59 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 434aa075e62d6..2d50064a5cdb5 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6590,6 +6590,23 @@ struct OrtApi { * \since Version 1.24 */ ORT_API_T(bool, TensorTypeAndShape_HasShape, _In_ const OrtTensorTypeAndShapeInfo* info); + + /** \brief Get all config entries from ::OrtKernelInfo. + * + * Gets all configuration entries from the ::OrtKernelInfo object as key-value pairs. + * Config entries are set on the ::OrtSessionOptions and are accessible in custom operator kernels. + * + * Used in the CreateKernel callback of an OrtCustomOp to access all session configuration entries + * during kernel construction. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[out] out A pointer to a newly created OrtKeyValuePairs instance containing all config entries. + * Note: the user should call OrtApi::ReleaseKeyValuePairs. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.24 + */ + ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d3a8856455c49..22708bbf06a3d 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2768,6 +2768,8 @@ struct KernelInfoImpl : Base { std::string GetNodeName() const; Logger GetLogger() const; + + KeyValuePairs GetConfigEntries() const; }; } // namespace detail diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 8ee057f51eb20..5144418db2b58 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2822,6 +2822,13 @@ inline Logger KernelInfoImpl::GetLogger() const { return Logger{out}; } +template +inline KeyValuePairs KernelInfoImpl::GetConfigEntries() const { + OrtKeyValuePairs* out = nullptr; + Ort::ThrowOnError(GetApi().KernelInfo_GetConfigEntries(this->p_, &out)); + return KeyValuePairs{out}; +} + inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) { Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out)); } diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 9bc6c8d0a96a1..6c6c589ffcad4 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -755,6 +755,21 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* i }); } +ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + const auto& config_options_map = op_info->GetConfigOptions().GetConfigOptionsMap(); + + auto kvps = std::make_unique(); + for (const auto& kv : config_options_map) { + kvps->Add(kv.first.c_str(), kv.second.c_str()); + } + + *out = kvps.release(); + return nullptr; + }); +} + ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) { if (count_or_bytes == 0) { *out = nullptr; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 546b11ae580d5..6e9e5f6311ae6 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4231,6 +4231,7 @@ static constexpr OrtApi ort_api_1_to_24 = { // End of Version 23 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::TensorTypeAndShape_HasShape, + &OrtApis::KernelInfo_GetConfigEntries, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index f016bb3215330..c0e4d32ac0167 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -751,4 +751,7 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); + +ORT_API_STATUS_IMPL(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); + } // namespace OrtApis diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index 2d5c3a43ee8ed..37c3825101ba4 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -78,7 +78,18 @@ TEST_F(ShapeInferenceTest, BasicTest) { namespace { struct MyCustomKernelWithOptionalInput { - MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) { + MyCustomKernelWithOptionalInput(const OrtKernelInfo* info) { + Ort::ConstKernelInfo k_info(info); + + Ort::KeyValuePairs kvp = k_info.GetConfigEntries(); + + EXPECT_NE(nullptr, kvp.GetValue("session.inter_op.allow_spinning")); + EXPECT_STREQ("0", kvp.GetValue("session.inter_op.allow_spinning")); + + EXPECT_NE(nullptr, kvp.GetValue("session.intra_op.allow_spinning")); + EXPECT_STREQ("0", kvp.GetValue("session.intra_op.allow_spinning")); + + EXPECT_EQ(nullptr, kvp.GetValue("__not__exist__")); } OrtStatusPtr ComputeV2(OrtKernelContext* /* context */) const { @@ -143,6 +154,8 @@ TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) { SessionOptions sess_opts; sess_opts.inter_op_param.thread_pool_size = 1; sess_opts.intra_op_param.thread_pool_size = 1; + ASSERT_STATUS_OK(sess_opts.config_options.AddConfigEntry("session.inter_op.allow_spinning", "0")); + ASSERT_STATUS_OK(sess_opts.config_options.AddConfigEntry("session.intra_op.allow_spinning", "0")); InferenceSessionWrapper session{sess_opts, env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2}; ASSERT_STATUS_OK(session.AddCustomOpDomains(AsSpan(op_domains)));