Skip to content
Open
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
8c03c61
Sketch API funcs
adrianlizarraga Sep 23, 2025
40d7866
Merge branch 'main' into adrianl/ep-abi-kernel-based-eps
adrianlizarraga Sep 26, 2025
ec46ea5
Implement more c apis
adrianlizarraga Sep 27, 2025
e8532a9
OpKernel class for plugin EPs
adrianlizarraga Sep 27, 2025
246681c
Initialize PluginExecutionProvider's kernel registry
adrianlizarraga Sep 28, 2025
6b61b91
Move files to session
adrianlizarraga Sep 28, 2025
4a112fc
Add API to set kernel def I/O memory types
adrianlizarraga Sep 28, 2025
0f870d0
Add C API to add type constraints to a kernel definition
adrianlizarraga Sep 28, 2025
9be6923
Start implementing MemcpyFromHost kernel in example EP
adrianlizarraga Sep 28, 2025
5df3fb5
Get kernel for MemcpyFromHost working for example plugin EP!
adrianlizarraga Sep 28, 2025
5aade60
Moved example plugin EP's kernel stuff to different file
adrianlizarraga Sep 28, 2025
d159b38
Add separate utility to load OrtMLDataTypes
adrianlizarraga Sep 29, 2025
8187233
Add MLDataTypes::GetTensorType()
adrianlizarraga Sep 29, 2025
bf92f04
Add C++ Ort::KernelDefBuilder to allow creation of macro
adrianlizarraga Sep 29, 2025
da14d65
Create macro for defining BuildKernelCreateInfo functions for each op
adrianlizarraga Sep 30, 2025
73ae307
Move kernels to separate directory/files
adrianlizarraga Sep 30, 2025
fb4a6a6
Use data transfer in MemcpyFromHost and MemcpyToHost
adrianlizarraga Sep 30, 2025
b8867d6
Release OrtKernelCreateInfo instances if an error occurs
adrianlizarraga Sep 30, 2025
de8be32
Move typedef and add forward-declaration of OrtKernelImpl for gcc
adrianlizarraga Sep 30, 2025
dc78ec3
Merge branch 'main' into adrianl/ep-abi-kernel-based-eps
adrianlizarraga Sep 30, 2025
8babb63
Apply suggestions from code review
adrianlizarraga Oct 1, 2025
90bf598
Simplify with OrtKernelRegistry
adrianlizarraga Oct 2, 2025
9f95589
Pass custom state to kernel creation in plugin EP
adrianlizarraga Oct 2, 2025
90e4fc1
Clean up
adrianlizarraga Oct 2, 2025
5f52cdc
ExampleEp: cache kernel registry in factory so it can be reused by al…
adrianlizarraga Oct 2, 2025
9bfc281
Add C API to lookup a kernel from within OrtEp::GetCapability
adrianlizarraga Oct 2, 2025
60ea06c
Disambiguate a compiled subgraph (of one node) from a registered kern…
adrianlizarraga Oct 2, 2025
33ffd8d
Add unit test for EpGraphSupportInfo_LookUpKernel()
adrianlizarraga Oct 2, 2025
31cdc82
Add missing include needed for linux ci
adrianlizarraga Oct 2, 2025
93b99e6
Add KernelDef to C++ api and add basic getters
adrianlizarraga Oct 2, 2025
995e25b
Add documentation comments
adrianlizarraga Oct 3, 2025
0f7145f
Remove incorrect comment
adrianlizarraga Oct 3, 2025
f528f6f
Add missing API_IMPL_BEGIN/END macro calls
adrianlizarraga Oct 6, 2025
9fbf230
Merge branch 'main' into adrianl/ep-abi-kernel-based-eps
adrianlizarraga Oct 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1976,7 +1976,9 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
NOT onnxruntime_MINIMAL_BUILD)
# example_plugin_ep
file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h"
"${TEST_SRC_DIR}/autoep/library/*.cc")
"${TEST_SRC_DIR}/autoep/library/*.cc"
"${TEST_SRC_DIR}/autoep/library/kernels/*.h"
"${TEST_SRC_DIR}/autoep/library/kernels/*.cc")
onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src})
target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
target_link_libraries(example_plugin_ep PRIVATE onnxruntime)
Expand Down
61 changes: 61 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,8 @@ ORT_DEFINE_RELEASE(ValueInfo);

ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDef, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi);

// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type,
// but the struct has V2 in its name to indicate that it is the second version of the options.
Expand Down Expand Up @@ -3284,5 +3286,64 @@ struct Model : detail::ModelImpl<OrtModel> {
explicit Model(const std::vector<DomainOpsetPair>& opsets);
#endif
};

namespace detail {
template <typename T>
struct ConstKernelDefImpl : Base<T> {
using B = Base<T>;
using B::B;

///< Wraps OrtEpApi::KernelDef_GetOperatorType
const char* GetOperatorType() const;

///< Wraps OrtEpApi::KernelDef_GetDomain
const char* GetDomain() const;

///< Wraps OrtEpApi::KernelDef_GetSinceVersion
std::pair<int, int> GetSinceVersion() const;

///< Wraps OrtEpApi::KernelDef_GetExecutionProvider
const char* GetExecutionProvider() const;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If any of the information for any getters is optional, suggest returning a status

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The underlying C API function OrtEpApi::KernelDef_GetExecutionProvider returns the const char* directly (doesn't return a status).

Sorry, I don't fully understand the comment. Is the request to make this return a status instead? What is meant by "information for any getters is optional"?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a general comment not to be bound to throwing exceptions by default. In this case, the data is returned without error reporting.


///< Wraps OrtEpApi::KernelDef_GetInputMemType
OrtMemType GetInputMemType(size_t input_index) const;

///< Wraps OrtEpApi::KernelDef_GetOutputMemType
OrtMemType GetOutputMemType(size_t output_index) const;
};
} // namespace detail

using ConstKernelDef = detail::ConstKernelDefImpl<detail::Unowned<const OrtKernelDef>>;

struct KernelDef : detail::ConstKernelDefImpl<OrtKernelDef> {
using Base = detail::ConstKernelDefImpl<OrtKernelDef>;
using Base::Base;

explicit KernelDef(std::nullptr_t) {}
explicit KernelDef(OrtKernelDef* p) : detail::ConstKernelDefImpl<OrtKernelDef>{p} {}

ConstKernelDef GetConst() const { return ConstKernelDef{this->p_}; }
};

/** \brief Builder for OrtKernelDef.
*
* Used by plugin EPs to build a kernel definition.
*/
struct KernelDefBuilder : detail::Base<OrtKernelDefBuilder> {
KernelDefBuilder(); ///< Wraps OrtEpApi::CreateKernelDefBuilder
explicit KernelDefBuilder(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
explicit KernelDefBuilder(OrtKernelDefBuilder* ort_kernel_def_builder);

KernelDefBuilder& SetOperatorType(const char* op_type);
KernelDefBuilder& SetDomain(const char* domain);
KernelDefBuilder& SetSinceVersion(int since_version_start, int since_version_end = -1);
KernelDefBuilder& SetExecutionProvider(const char* ep_name);
KernelDefBuilder& SetInputMemType(size_t input_index, OrtMemType mem_type);
KernelDefBuilder& SetOutputMemType(size_t output_index, OrtMemType mem_type);
KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtMLDataType* data_types);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it is one data type with this overload, right?

Suggested change
KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtMLDataType* data_types);
KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtMLDataType* data_type);

KernelDefBuilder& AddTypeConstraint(const char* arg_name, const std::vector<const OrtMLDataType*>& data_types);

KernelDef Build();
};
} // namespace Ort
#include "onnxruntime_cxx_inline.h"
97 changes: 97 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -3553,4 +3553,101 @@
}
#endif

namespace detail {
template <typename T>
inline const char* ConstKernelDefImpl<T>::GetOperatorType() const {
return GetEpApi().KernelDef_GetOperatorType(this->p_);
}

template <typename T>
inline const char* ConstKernelDefImpl<T>::GetDomain() const {
return GetEpApi().KernelDef_GetDomain(this->p_);
}

template <typename T>
inline std::pair<int, int> ConstKernelDefImpl<T>::GetSinceVersion() const {
int start = 0;
int end = 0;

ThrowOnError(GetEpApi().KernelDef_GetSinceVersion(this->p_, &start, &end));
return std::pair<int, int>(start, end);

Check warning on line 3573 in include/onnxruntime/core/session/onnxruntime_cxx_inline.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for pair<> [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_inline.h:3573: Add #include <utility> for pair<> [build/include_what_you_use] [4]
}

template <typename T>
inline const char* ConstKernelDefImpl<T>::GetExecutionProvider() const {
return GetEpApi().KernelDef_GetExecutionProvider(this->p_);
}

template <typename T>
inline OrtMemType ConstKernelDefImpl<T>::GetInputMemType(size_t input_index) const {
OrtMemType mem_type{};
ThrowOnError(GetEpApi().KernelDef_GetInputMemType(this->p_, input_index, &mem_type));

return mem_type;
}

template <typename T>
inline OrtMemType ConstKernelDefImpl<T>::GetOutputMemType(size_t output_index) const {
OrtMemType mem_type{};
ThrowOnError(GetEpApi().KernelDef_GetOutputMemType(this->p_, output_index, &mem_type));

return mem_type;
}
} // namespace detail

inline KernelDefBuilder::KernelDefBuilder() {
ThrowOnError(GetEpApi().CreateKernelDefBuilder(&p_));
}

inline KernelDefBuilder::KernelDefBuilder(OrtKernelDefBuilder* p) : detail::Base<OrtKernelDefBuilder>{p} {
}

inline KernelDefBuilder& KernelDefBuilder::SetOperatorType(const char* op_type) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetOperatorType(p_, op_type));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::SetDomain(const char* domain) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetDomain(p_, domain));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::SetSinceVersion(int since_version_start, int since_version_end) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetSinceVersion(p_, since_version_start, since_version_end));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::SetExecutionProvider(const char* ep_name) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetExecutionProvider(p_, ep_name));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::SetInputMemType(size_t input_index, OrtMemType mem_type) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetInputMemType(p_, input_index, mem_type));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::SetOutputMemType(size_t output_index, OrtMemType mem_type) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetOutputMemType(p_, output_index, mem_type));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_name,
const OrtMLDataType* data_types) {
ThrowOnError(GetEpApi().KernelDefBuilder_AddTypeConstraint(p_, arg_name, &data_types, 1));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_name,
const std::vector<const OrtMLDataType*>& data_types) {
ThrowOnError(GetEpApi().KernelDefBuilder_AddTypeConstraint(p_, arg_name, data_types.data(), data_types.size()));
return *this;
}

inline KernelDef KernelDefBuilder::Build() {
OrtKernelDef* kernel_def = nullptr;
ThrowOnError(GetEpApi().KernelDefBuilder_Build(p_, &kernel_def));
return KernelDef(kernel_def);
}

} // namespace Ort
Loading
Loading