diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 177bc4229df31..c129c6f9acb72 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1979,7 +1979,9 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND NOT onnxruntime_MINIMAL_BUILD) # example_plugin_ep file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h" - "${TEST_SRC_DIR}/autoep/library/*.cc") + "${TEST_SRC_DIR}/autoep/library/*.cc" + "${TEST_SRC_DIR}/autoep/library/kernels/*.h" + "${TEST_SRC_DIR}/autoep/library/kernels/*.cc") onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src}) target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) target_link_libraries(example_plugin_ep PRIVATE onnxruntime) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d3a8856455c49..aa0f244b13fb6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -644,6 +644,8 @@ ORT_DEFINE_RELEASE(ValueInfo); ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); +ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDef, GetEpApi); +ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi); // This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type, // but the struct has V2 in its name to indicate that it is the second version of the options. @@ -3284,5 +3286,64 @@ struct Model : detail::ModelImpl { explicit Model(const std::vector& opsets); #endif }; + +namespace detail { +template +struct ConstKernelDefImpl : Base { + using B = Base; + using B::B; + + ///< Wraps OrtEpApi::KernelDef_GetOperatorType + const char* GetOperatorType() const; + + ///< Wraps OrtEpApi::KernelDef_GetDomain + const char* GetDomain() const; + + ///< Wraps OrtEpApi::KernelDef_GetSinceVersion + std::pair GetSinceVersion() const; + + ///< Wraps OrtEpApi::KernelDef_GetExecutionProvider + const char* GetExecutionProvider() const; + + ///< Wraps OrtEpApi::KernelDef_GetInputMemType + OrtMemType GetInputMemType(size_t input_index) const; + + ///< Wraps OrtEpApi::KernelDef_GetOutputMemType + OrtMemType GetOutputMemType(size_t output_index) const; +}; +} // namespace detail + +using ConstKernelDef = detail::ConstKernelDefImpl>; + +struct KernelDef : detail::ConstKernelDefImpl { + using Base = detail::ConstKernelDefImpl; + using Base::Base; + + explicit KernelDef(std::nullptr_t) {} + explicit KernelDef(OrtKernelDef* p) : detail::ConstKernelDefImpl{p} {} + + ConstKernelDef GetConst() const { return ConstKernelDef{this->p_}; } +}; + +/** \brief Builder for OrtKernelDef. + * + * Used by plugin EPs to build a kernel definition. + */ +struct KernelDefBuilder : detail::Base { + KernelDefBuilder(); ///< Wraps OrtEpApi::CreateKernelDefBuilder + explicit KernelDefBuilder(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used + explicit KernelDefBuilder(OrtKernelDefBuilder* ort_kernel_def_builder); + + KernelDefBuilder& SetOperatorType(const char* op_type); + KernelDefBuilder& SetDomain(const char* domain); + KernelDefBuilder& SetSinceVersion(int since_version_start, int since_version_end = -1); + KernelDefBuilder& SetExecutionProvider(const char* ep_name); + KernelDefBuilder& SetInputMemType(size_t input_index, OrtMemType mem_type); + KernelDefBuilder& SetOutputMemType(size_t output_index, OrtMemType mem_type); + KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtMLDataType* data_types); + KernelDefBuilder& AddTypeConstraint(const char* arg_name, const std::vector& data_types); + + KernelDef Build(); +}; } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 8ee057f51eb20..f81eba2d5566a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -3553,4 +3553,101 @@ inline Model::Model(const std::vector& opsets) { } #endif +namespace detail { +template +inline const char* ConstKernelDefImpl::GetOperatorType() const { + return GetEpApi().KernelDef_GetOperatorType(this->p_); +} + +template +inline const char* ConstKernelDefImpl::GetDomain() const { + return GetEpApi().KernelDef_GetDomain(this->p_); +} + +template +inline std::pair ConstKernelDefImpl::GetSinceVersion() const { + int start = 0; + int end = 0; + + ThrowOnError(GetEpApi().KernelDef_GetSinceVersion(this->p_, &start, &end)); + return std::pair(start, end); +} + +template +inline const char* ConstKernelDefImpl::GetExecutionProvider() const { + return GetEpApi().KernelDef_GetExecutionProvider(this->p_); +} + +template +inline OrtMemType ConstKernelDefImpl::GetInputMemType(size_t input_index) const { + OrtMemType mem_type{}; + ThrowOnError(GetEpApi().KernelDef_GetInputMemType(this->p_, input_index, &mem_type)); + + return mem_type; +} + +template +inline OrtMemType ConstKernelDefImpl::GetOutputMemType(size_t output_index) const { + OrtMemType mem_type{}; + ThrowOnError(GetEpApi().KernelDef_GetOutputMemType(this->p_, output_index, &mem_type)); + + return mem_type; +} +} // namespace detail + +inline KernelDefBuilder::KernelDefBuilder() { + ThrowOnError(GetEpApi().CreateKernelDefBuilder(&p_)); +} + +inline KernelDefBuilder::KernelDefBuilder(OrtKernelDefBuilder* p) : detail::Base{p} { +} + +inline KernelDefBuilder& KernelDefBuilder::SetOperatorType(const char* op_type) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetOperatorType(p_, op_type)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::SetDomain(const char* domain) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetDomain(p_, domain)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::SetSinceVersion(int since_version_start, int since_version_end) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetSinceVersion(p_, since_version_start, since_version_end)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::SetExecutionProvider(const char* ep_name) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetExecutionProvider(p_, ep_name)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::SetInputMemType(size_t input_index, OrtMemType mem_type) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetInputMemType(p_, input_index, mem_type)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::SetOutputMemType(size_t output_index, OrtMemType mem_type) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetOutputMemType(p_, output_index, mem_type)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_name, + const OrtMLDataType* data_types) { + ThrowOnError(GetEpApi().KernelDefBuilder_AddTypeConstraint(p_, arg_name, &data_types, 1)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_name, + const std::vector& data_types) { + ThrowOnError(GetEpApi().KernelDefBuilder_AddTypeConstraint(p_, arg_name, data_types.data(), data_types.size())); + return *this; +} + +inline KernelDef KernelDefBuilder::Build() { + OrtKernelDef* kernel_def = nullptr; + ThrowOnError(GetEpApi().KernelDefBuilder_Build(p_, &kernel_def)); + return KernelDef(kernel_def); +} + } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 975f6b453a88d..0ea30610bb732 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -17,6 +17,13 @@ ORT_RUNTIME_CLASS(DataTransferImpl); ORT_RUNTIME_CLASS(SyncNotificationImpl); ORT_RUNTIME_CLASS(SyncStreamImpl); +// Opaque types for kernel-based EPs +ORT_RUNTIME_CLASS(KernelRegistry); +ORT_RUNTIME_CLASS(KernelCreateContext); // stand-in for FuncManager. may not be needed. +ORT_RUNTIME_CLASS(KernelDefBuilder); +ORT_RUNTIME_CLASS(KernelDef); +ORT_RUNTIME_CLASS(MLDataType); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType + // struct that an EP implements for IDataTransfer to copy between devices it uses and CPU struct OrtDataTransferImpl { uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION @@ -264,6 +271,53 @@ struct OrtNodeComputeInfo { void(ORT_API_CALL* ReleaseState)(_In_ OrtNodeComputeInfo* this_ptr, _Frees_ptr_opt_ void* compute_state); }; +struct OrtKernelImpl; +typedef struct OrtKernelImpl OrtKernelImpl; + +/** + * \brief Contains functions that an OrtEp implements to specify the computation for an operator kernel. + * \since Version 1.24. + */ +struct OrtKernelImpl { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + /** \brief Computation function called to execute the kernel on an EP. + * + * \param[in] this_ptr The OrtKernelImpl instance. + * \param[in] context The OrtKernelContext instance that provides access to the inputs and outputs. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(Compute, _In_ OrtKernelImpl* this_ptr, _In_ OrtKernelContext* context); + + /** \brief Called by ORT to release the OrtKernelImpl instance and its resources. + * + * \param[in] this_ptr The OrtKernelImpl instance. + * + * \since Version 1.24. + */ + ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr); +}; + +/** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel. + * + * \param[in] ctx Unused/reserved for future use. + * \param[in] kernel_create_func_state Opaque state initially provided by the EP that registered the kernel. + * Refer to OrtEpApi::KernelRegistry_AddKernel(). May be null. + * \param[in] info The OrtKernelInfo instance that provides access to the kernel's input and output characteristics. + * \param[out] kernel_out Output parameter set to the new OrtKernelImpl instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ +typedef OrtStatus*(ORT_API_CALL* OrtKernelCreateFunc)(_In_ OrtKernelCreateContext* ctx, // unused/reserved as of 1.24 + _In_ void* kernel_create_func_state, + _In_ const OrtKernelInfo* info, + _Outptr_result_maybenull_ OrtKernelImpl** kernel_out); + struct OrtEpApi { /** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice. * \param[in] ep_factory Execution provider factory that is creating the instance. @@ -465,6 +519,279 @@ struct OrtEpApi { */ ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream); + + /** \brief Creates an empty kernel registry. A kernel registry contains kernel creation information for + * every operator kernel supported by an EP. + * + * \remarks Refer to OrtEp::GetKernelRegistry, which returns an EP's kernel registry to ORT. + * + * \param[out] kernel_registry Output parameter set to the new OrtKernelRegistry instance. + * Must be released with OrtEpApi::ReleaseKernelRegistry. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateKernelRegistry, _Outptr_ OrtKernelRegistry** kernel_registry); + + ORT_CLASS_RELEASE(KernelRegistry); + + /** \brief Adds kernel creation information for a supported operator kernel to the given kernel registry. + * + * \remarks Refer to OrtEp::GetKernelRegistry, which returns an EP's kernel registry to ORT. + * + * \param[in] kernel_registry The OrtKernelRegistry instance. + * \param[in] kernel_def The kernel definition, which includes operator type, version, EP name, type constraints, etc. + * \param[in] kernel_create_func Function that creates an instance of the operator kernel as a OrtKernelImpl instance. + * \param[in] kernel_create_func_state Custom state passed to the kernel creation function. Can be null. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelRegistry_AddKernel, _In_ OrtKernelRegistry* kernel_registry, + _In_ const OrtKernelDef* kernel_def, _In_ OrtKernelCreateFunc kernel_create_func, + _In_ void* kernel_create_func_state); + + /** \brief Creates a kernel definition builder used to create instances of OrtKernelDef. + * + * \param[out] kernel_def_builder_out Output parameter set to the new OrtKernelDefBuilder instance. + * Must be released with OrtEpApi::ReleaseKernelDefBuilder(). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateKernelDefBuilder, _Outptr_ OrtKernelDefBuilder** kernel_def_builder_out); + + ORT_CLASS_RELEASE(KernelDefBuilder); + + /** \brief Sets the kernel's operator type. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] op_type A null-terminated string representing the operator type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetOperatorType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* op_type); + + /** \brief Sets the kernel's domain. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] domain A null-terminated string representing the operator's domain. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetDomain, _In_ OrtKernelDefBuilder* kernel_def_builder, _In_ const char* domain); + + /** \brief Sets the kernel's opset version range that is supported. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] since_version_start The starting opset version that is supported. + * \param[in] since_version_end The ending opset version (inclusive) that is supported. + * Can be set to -1 to indicate the latest opset version. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetSinceVersion, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ int since_version_start, _In_ int since_version_end); + + /** \brief Sets the name of the kernel's intended execution provider. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] ep_name A null-terminated string representing the execution provider's name. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetExecutionProvider, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* ep_name); + + /** \brief Sets the memory type for a kernel input. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] input_index The index of the input. + * \param[in] mem_type The input's memory type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetInputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t input_index, _In_ OrtMemType mem_type); + + /** \brief Sets the memory type for a kernel output. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] output_index The index of the output. + * \param[in] mem_type The output's memory type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetOutputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t output_index, _In_ OrtMemType mem_type); + + /** \brief Sets type constraints for a kernel argument represented as a string (e.g., "T"). + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] arg_name A null-terminated string representing the argument to constrain (e.g., "T"). + * \param[in] types Array of OrtMLDataType instances representing allowed types for the argument. + * Must contain `num_types` elements. + * \param[in] num_types The number of OrtMLDataType elements in the `types` array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_AddTypeConstraint, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* arg_name, _In_reads_(num_types) const OrtMLDataType* const* types, + _In_ size_t num_types); + + /** \brief Creates a OrtKernelDef instance from the given kernel definition builder. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[out] kernel_def_out The new OrtKernelDef instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_Build, _In_ OrtKernelDefBuilder* kernel_def_builder, + _Outptr_ OrtKernelDef** kernel_def_out); + + ORT_CLASS_RELEASE(KernelDef); + + /** \brief Returns the operator type from the kernel definition. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \return A null-terminated string representing the operator type. + * + * \since Version 1.24. + */ + ORT_API_T(const char*, KernelDef_GetOperatorType, _In_ const OrtKernelDef* kernel_def); + + /** \brief Returns the operator's domain from the kernel definition. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \return A null-terminated string representing the operator's domain. + * + * \since Version 1.24. + */ + ORT_API_T(const char*, KernelDef_GetDomain, _In_ const OrtKernelDef* kernel_def); + + /** \brief Gets the kernel's opset version range that is supported. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \param[out] version_start Output parameter set to the starting opset version that is supported. + * \param[out] version_end Output parameter set to the ending opset version (inclusive) that is supported. + * Set to -1 to indicate the latest opset version. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDef_GetSinceVersion, _In_ const OrtKernelDef* kernel_def, + _Out_ int* start_version, _Out_ int* end_version); + + /** \brief Returns the name of the kernel's intended execution provider. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \return A null-terminated string representing the name of the execution provider. + * + * \since Version 1.24. + */ + ORT_API_T(const char*, KernelDef_GetExecutionProvider, _In_ const OrtKernelDef* kernel_def); + + /** \brief Gets the memory type for a kernel input. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \param[in] input_index The index of the input. + * \param[out] mem_type Output parameter set to the input's memory type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDef_GetInputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t input_index, _Out_ OrtMemType* mem_type); + + /** \brief Gets the memory type for a kernel output. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \param[in] output_index The index of the output. + * \param[out] mem_type Output parameter set to the output's memory type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDef_GetOutputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t output_index, _Out_ OrtMemType* mem_type); + + /** \brief Gets the OrtMLDataType that represents the data type for a tensor of the given element type. + * + * \param[in] elem_type The tensor's element type. + * \param[out] out Output parameter set to the OrtMLDataType. Owned by ORT and must not be released. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetTensorMLDataType, _In_ ONNXTensorElementDataType elem_type, + _Outptr_ const OrtMLDataType** out); + + /** \brief Copy OrtValue instances containing Tensors between devices. + * + * The overall copy must be between a single source device and a single destination device. i.e. + * - all src_tensors must have matching OrtMemoryInfo, + * - all dst_tensors must have matching OrtMemoryInfo. + * + * OrtValue instances should be obtained from the OrtKernelContext instanced provided to a kernel's compute function. + * Refer to OrtKernelImpl::Compute(). + * + * \param[in] info The OrtKernelInfo instance, which contains references to available data transfer implementations. + * \param[in] src_tensors Array of OrtValue instances containing the source tensors to copy. + * \param[in] dst_tensors Array of OrtValue instances to copy the source tensors to. + * \param[in] stream Optional OrtSyncStream that can be used to perform the copy asynchronously. May be nullptr. + * \param[in] num_tensors The number of tensors to copy. The size of `src_tensors` and `dst_tensors` must match. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24 + */ + ORT_API2_STATUS(KernelInfo_CopyTensors, _In_ const OrtKernelInfo* info, + _In_reads_(num_tensors) const OrtValue* const* src_tensors, + _In_reads_(num_tensors) OrtValue* const* dst_tensors, + _In_opt_ OrtSyncStream* stream, + _In_ size_t num_tensors); + + /** \brief Gets the kernel definition for a given node, if any exists for the calling execution provider. + * + * Used within OrtEp::GetCapability() to get the registered kernel definition for the given node. + * The kernel definition is set to NULL if there is no registered kernel definition for the node + * and execution provider. + * + * \param[in] graph_support_info The OrtEpGraphSupportInfo instance to query. + * \param[in] node The node for which to look up a kernel definition. + * \param[out] out_kernel_def Output parameter set to the OrtKernelDef or NULL. + * Owned by ORT and must not be released. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def); }; /** @@ -726,6 +1053,22 @@ struct OrtEp { */ ORT_API_T(const char*, GetCompiledModelCompatibilityInfo, _In_ OrtEp* this_ptr, _In_ const OrtGraph* graph); + + /** \brief Gets the execution provider's kernel registry, if any. + * + * A kernel registry contains kernel creation information for operator kernels supported by an EP. + * + * \param[in] this_ptr The OrtEp instance. + * \param[out] kernel_registry Output parameter set to the EP's kernel registry, which must remain valid throughout + * the lifetime of the EP. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. If set to NULL, ORT assumes the EP compiles nodes. + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetKernelRegistry, _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. diff --git a/onnxruntime/core/session/abi_ep_types.cc b/onnxruntime/core/session/abi_ep_types.cc index 14764251898aa..5f45ea0a2b808 100644 --- a/onnxruntime/core/session/abi_ep_types.cc +++ b/onnxruntime/core/session/abi_ep_types.cc @@ -10,6 +10,10 @@ #include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" +OrtEpGraphSupportInfo::OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) + : ort_graph(graph), kernel_lookup{kernel_lookup} {} + onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span nodes, const OrtNodeFusionOptions* optional_fusion_options) { std::vector ep_nodes; diff --git a/onnxruntime/core/session/abi_ep_types.h b/onnxruntime/core/session/abi_ep_types.h index eb68d79a24279..deaadf7c67e6e 100644 --- a/onnxruntime/core/session/abi_ep_types.h +++ b/onnxruntime/core/session/abi_ep_types.h @@ -10,6 +10,7 @@ #include "core/common/inlined_containers_fwd.h" #include "core/common/status.h" +#include "core/framework/execution_provider.h" #include "core/session/onnxruntime_c_api.h" namespace onnxruntime { @@ -39,7 +40,8 @@ struct OrtEpGraphSupportInfo { OrtNodeFusionOptions fusion_options = {}; }; - explicit OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph) : ort_graph(graph) {} + OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup); onnxruntime::Status AddNodesToFuse(gsl::span nodes, const OrtNodeFusionOptions* node_fusion_options = nullptr); @@ -47,4 +49,5 @@ struct OrtEpGraphSupportInfo { const onnxruntime::EpGraph& ort_graph; std::vector node_groupings; + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup; }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 9b258d0983570..fdb76330364b4 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3430,53 +3430,14 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env, return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments provided to CopyTensors."); } - const OrtMemoryInfo* src_memory_info = nullptr; - const OrtMemoryInfo* dst_memory_info = nullptr; - - const auto validate_and_get_mem_info = - [](const OrtValue* const* values, size_t num_values, const OrtMemoryInfo*& mem_info) -> OrtStatus* { - for (size_t i = 0; i < num_values; ++i) { - const OrtValue* value = values[i]; - if (value == nullptr || !value->IsTensor() || !value->IsAllocated()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue must contain Tensor with data."); - } - - if (i == 0) { - mem_info = &value->Get().Location(); - } else if (*mem_info != value->Get().Location()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "All OrtValue instances must have the same OrtMemoryInfo"); - } - } - - return nullptr; - }; - - ORT_API_RETURN_IF_ERROR(validate_and_get_mem_info(src_tensors, num_tensors, src_memory_info)); - ORT_API_RETURN_IF_ERROR(validate_and_get_mem_info(const_cast(dst_tensors), num_tensors, - dst_memory_info)); - auto& data_transfer_mgr = env->GetEnvironment().GetDataTransferManager(); - const auto* data_transfer = data_transfer_mgr.GetDataTransfer(src_memory_info->device, dst_memory_info->device); - if (data_transfer == nullptr) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, - "Data transfer implementation between source and destination device was not found."); - } - - std::vector pairs; - pairs.reserve(num_tensors); - for (size_t i = 0; i < num_tensors; ++i) { - pairs.push_back({ - src_tensors[i]->Get(), - *dst_tensors[i]->GetMutable(), - stream, - }); - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(data_transfer->CopyTensors(pairs)); + ORT_API_RETURN_IF_STATUS_NOT_OK(CopyTensors(data_transfer_mgr, + gsl::span(src_tensors, num_tensors), + gsl::span(dst_tensors, num_tensors), + stream)); return nullptr; - API_IMPL_END } diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index cae0b086af66c..e0541ad073d23 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -4,11 +4,14 @@ #include "core/session/plugin_ep/ep_api.h" #include +#include +#include #include #include "core/common/semver.h" #include "core/framework/error_code_helper.h" #include "core/framework/func_api.h" +#include "core/framework/op_kernel_info.h" #include "core/framework/ort_value.h" #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" @@ -19,6 +22,8 @@ #include "core/session/abi_ep_types.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/ort_apis.h" +#include "core/session/plugin_ep/ep_kernel_registration.h" +#include "core/session/utils.h" using namespace onnxruntime; namespace OrtExecutionProviderApi { @@ -71,7 +76,7 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInf _In_opt_ const OrtNodeFusionOptions* node_fusion_options) { API_IMPL_BEGIN if (ort_graph_support_info == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtGraph instance"); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtEpGraphSupportInfo instance"); } if (num_nodes == 0 || nodes == nullptr) { @@ -205,6 +210,277 @@ ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* pr return id; } +ORT_API_STATUS_IMPL(CreateKernelRegistry, _Outptr_ OrtKernelRegistry** kernel_registry) { + API_IMPL_BEGIN + auto unique_kernel_registry = std::make_unique(); + unique_kernel_registry->registry = std::make_shared(); + + *kernel_registry = unique_kernel_registry.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseKernelRegistry, _Frees_ptr_opt_ OrtKernelRegistry* kernel_registry) { + delete kernel_registry; +} + +ORT_API_STATUS_IMPL(KernelRegistry_AddKernel, _In_ OrtKernelRegistry* kernel_registry, + _In_ const OrtKernelDef* kernel_def, _In_ OrtKernelCreateFunc kernel_create_func, + _In_ void* kernel_create_func_state) { + API_IMPL_BEGIN + if (kernel_registry == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null OrtKernelRegistry"); + } + + if (kernel_def == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null OrtKernelDef"); + } + + if (kernel_create_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null OrtKernelCreateFunc"); + } + + KernelCreateInfo kernel_create_info = MakePluginEpKernelCreateInfo(static_cast(kernel_def), + kernel_create_func, kernel_create_func_state); + + ORT_API_RETURN_IF_STATUS_NOT_OK(kernel_registry->registry->Register(std::move(kernel_create_info))); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(CreateKernelDefBuilder, _Outptr_ OrtKernelDefBuilder** kernel_def_builder_out) { + API_IMPL_BEGIN + auto builder = onnxruntime::KernelDefBuilder::Create(); + *kernel_def_builder_out = static_cast(builder.release()); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseKernelDefBuilder, _Frees_ptr_opt_ OrtKernelDefBuilder* kernel_def_builder) { + delete kernel_def_builder; +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetOperatorType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* op_type) { + API_IMPL_BEGIN + kernel_def_builder->SetName(op_type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetDomain, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* domain) { + API_IMPL_BEGIN + kernel_def_builder->SetDomain(domain); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetSinceVersion, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ int since_version_start, _In_ int since_version_end) { + API_IMPL_BEGIN + kernel_def_builder->SinceVersion(since_version_start, since_version_end); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetExecutionProvider, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* ep_name) { + API_IMPL_BEGIN + kernel_def_builder->Provider(ep_name); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetInputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t input_index, _In_ OrtMemType mem_type) { + API_IMPL_BEGIN + kernel_def_builder->InputMemoryType(mem_type, static_cast(input_index)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetOutputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t output_index, _In_ OrtMemType mem_type) { + API_IMPL_BEGIN + kernel_def_builder->OutputMemoryType(mem_type, static_cast(output_index)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_AddTypeConstraint, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* arg_name, _In_reads_(num_types) const OrtMLDataType* const* types, + _In_ size_t num_types) { + API_IMPL_BEGIN + if (num_types == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify at least one OrtMLDataType instance"); + } + + if (types == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of OrtMLDataType instances"); + } + + if (arg_name == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a valid name for a kernel definition's type constraint"); + } + + std::vector ml_types; + ml_types.reserve(num_types); + + for (size_t i = 0; i < num_types; i++) { + ml_types.push_back(static_cast(types[i])); + } + + kernel_def_builder->TypeConstraint(arg_name, std::move(ml_types)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_Build, _In_ OrtKernelDefBuilder* kernel_def_builder, + _Outptr_ OrtKernelDef** kernel_def_out) { + API_IMPL_BEGIN + *kernel_def_out = static_cast(kernel_def_builder->Build().release()); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseKernelDef, _Frees_ptr_opt_ OrtKernelDef* kernel_def) { + delete kernel_def; +} + +ORT_API(const char*, KernelDef_GetOperatorType, _In_ const OrtKernelDef* kernel_def) { + return static_cast(kernel_def)->OpName().c_str(); +} + +ORT_API(const char*, KernelDef_GetDomain, _In_ const OrtKernelDef* kernel_def) { + return static_cast(kernel_def)->Domain().c_str(); +} + +ORT_API_STATUS_IMPL(KernelDef_GetSinceVersion, _In_ const OrtKernelDef* kernel_def, + _Out_ int* start_version, _Out_ int* end_version) { + API_IMPL_BEGIN + if (kernel_def == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid non-null OrtKernelDef"); + } + + if (start_version == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null `start_version` output parameter"); + } + + if (end_version == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null `end_version` output parameter"); + } + + auto internal_kernel_def = static_cast(kernel_def); + internal_kernel_def->SinceVersion(start_version, end_version); + + return nullptr; + API_IMPL_END +} + +ORT_API(const char*, KernelDef_GetExecutionProvider, _In_ const OrtKernelDef* kernel_def) { + return static_cast(kernel_def)->Provider().c_str(); +} + +ORT_API_STATUS_IMPL(KernelDef_GetInputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t input_index, _Out_ OrtMemType* mem_type) { + API_IMPL_BEGIN + if (kernel_def == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid non-null OrtKernelDef"); + } + + if (mem_type == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null `mem_type` output parameter"); + } + + auto internal_kernel_def = static_cast(kernel_def); + *mem_type = internal_kernel_def->InputMemoryType(input_index); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDef_GetOutputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t output_index, _Out_ OrtMemType* mem_type) { + API_IMPL_BEGIN + if (kernel_def == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid non-null OrtKernelDef"); + } + + if (mem_type == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null `mem_type` output parameter"); + } + + auto internal_kernel_def = static_cast(kernel_def); + *mem_type = internal_kernel_def->OutputMemoryType(output_index); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(GetTensorMLDataType, _In_ ONNXTensorElementDataType elem_type, + _Outptr_ const OrtMLDataType** out) { + API_IMPL_BEGIN + const DataTypeImpl* ml_type = DataTypeImpl::TensorTypeFromONNXEnum(elem_type); + *out = static_cast(ml_type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelInfo_CopyTensors, _In_ const OrtKernelInfo* info, + _In_reads_(num_tensors) const OrtValue* const* src_tensors, + _In_reads_(num_tensors) OrtValue* const* dst_tensors, + _In_opt_ OrtSyncStream* stream, + _In_ size_t num_tensors) { + API_IMPL_BEGIN + if (info == nullptr || src_tensors == nullptr || dst_tensors == nullptr || num_tensors == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments provided to KernelInfo_CopyTensors."); + } + + auto op_kernel_info = reinterpret_cast(info); + auto& data_transfer_mgr = op_kernel_info->GetDataTransferManager(); + + ORT_API_RETURN_IF_STATUS_NOT_OK(CopyTensors(data_transfer_mgr, + gsl::span(src_tensors, num_tensors), + gsl::span(dst_tensors, num_tensors), + stream)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def) { + API_IMPL_BEGIN + if (out_kernel_def == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null OrtKernelDef output parameter"); + } + + *out_kernel_def = nullptr; + + if (graph_support_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid non-null OrtEpGraphSupportInfo instance"); + } + + if (node == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid non-null OrtNode instance"); + } + + const onnxruntime::EpNode* ep_node = onnxruntime::EpNode::ToInternal(node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "OrtNode created via the ModelEditor API is not supported"); + } + + const onnxruntime::KernelCreateInfo* create_info = + graph_support_info->kernel_lookup.LookUpKernel(ep_node->GetInternalNode()); + + *out_kernel_def = static_cast(create_info->kernel_def.get()); + return nullptr; + API_IMPL_END +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -230,6 +506,29 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::SyncStream_GetImpl, &OrtExecutionProviderApi::SyncStream_GetSyncId, &OrtExecutionProviderApi::GetSyncIdForLastWaitOnSyncStream, + &OrtExecutionProviderApi::CreateKernelRegistry, + &OrtExecutionProviderApi::ReleaseKernelRegistry, + &OrtExecutionProviderApi::KernelRegistry_AddKernel, + &OrtExecutionProviderApi::CreateKernelDefBuilder, + &OrtExecutionProviderApi::ReleaseKernelDefBuilder, + &OrtExecutionProviderApi::KernelDefBuilder_SetOperatorType, + &OrtExecutionProviderApi::KernelDefBuilder_SetDomain, + &OrtExecutionProviderApi::KernelDefBuilder_SetSinceVersion, + &OrtExecutionProviderApi::KernelDefBuilder_SetExecutionProvider, + &OrtExecutionProviderApi::KernelDefBuilder_SetInputMemType, + &OrtExecutionProviderApi::KernelDefBuilder_SetOutputMemType, + &OrtExecutionProviderApi::KernelDefBuilder_AddTypeConstraint, + &OrtExecutionProviderApi::KernelDefBuilder_Build, + &OrtExecutionProviderApi::ReleaseKernelDef, + &OrtExecutionProviderApi::KernelDef_GetOperatorType, + &OrtExecutionProviderApi::KernelDef_GetDomain, + &OrtExecutionProviderApi::KernelDef_GetSinceVersion, + &OrtExecutionProviderApi::KernelDef_GetExecutionProvider, + &OrtExecutionProviderApi::KernelDef_GetInputMemType, + &OrtExecutionProviderApi::KernelDef_GetOutputMemType, + &OrtExecutionProviderApi::GetTensorMLDataType, + &OrtExecutionProviderApi::KernelInfo_CopyTensors, + &OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index c0dc79f3fb333..f637dc00d6c6b 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -40,4 +40,54 @@ ORT_API(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* ORT_API(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream); ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream); + +// OrtKernelRegistry +ORT_API_STATUS_IMPL(CreateKernelRegistry, _Outptr_ OrtKernelRegistry** kernel_registry); +ORT_API(void, ReleaseKernelRegistry, _Frees_ptr_opt_ OrtKernelRegistry* kernel_registry); +ORT_API_STATUS_IMPL(KernelRegistry_AddKernel, _In_ OrtKernelRegistry* kernel_registry, + _In_ const OrtKernelDef* kernel_def, _In_ OrtKernelCreateFunc kernel_create_func, + _In_ void* kernel_create_func_state); + +// OrtKernelDefBuilder +ORT_API_STATUS_IMPL(CreateKernelDefBuilder, _Outptr_ OrtKernelDefBuilder** kernel_def_builder_out); +ORT_API(void, ReleaseKernelDefBuilder, _Frees_ptr_opt_ OrtKernelDefBuilder* kernel_def_builder); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetOperatorType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* op_type); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetDomain, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* domain); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetSinceVersion, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ int since_version_start, _In_ int since_version_end); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetExecutionProvider, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* ep_name); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetInputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t input_index, _In_ OrtMemType mem_type); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetOutputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t output_index, _In_ OrtMemType mem_type); +ORT_API_STATUS_IMPL(KernelDefBuilder_AddTypeConstraint, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* arg_name, _In_reads_(num_types) const OrtMLDataType* const* types, + _In_ size_t num_types); +ORT_API_STATUS_IMPL(KernelDefBuilder_Build, _In_ OrtKernelDefBuilder* kernel_def_builder, + _Outptr_ OrtKernelDef** kernel_def_out); + +// OrtKernelDef +ORT_API(void, ReleaseKernelDef, _Frees_ptr_opt_ OrtKernelDef* kernel_def); +ORT_API(const char*, KernelDef_GetOperatorType, _In_ const OrtKernelDef* kernel_def); +ORT_API(const char*, KernelDef_GetDomain, _In_ const OrtKernelDef* kernel_def); +ORT_API_STATUS_IMPL(KernelDef_GetSinceVersion, _In_ const OrtKernelDef* kernel_def, + _Out_ int* start_version, _Out_ int* end_version); +ORT_API(const char*, KernelDef_GetExecutionProvider, _In_ const OrtKernelDef* kernel_def); +ORT_API_STATUS_IMPL(KernelDef_GetInputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t input_index, _Out_ OrtMemType* mem_type); +ORT_API_STATUS_IMPL(KernelDef_GetOutputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t output_index, _Out_ OrtMemType* mem_type); + +ORT_API_STATUS_IMPL(GetTensorMLDataType, _In_ ONNXTensorElementDataType elem_type, + _Outptr_ const OrtMLDataType** out); +ORT_API_STATUS_IMPL(KernelInfo_CopyTensors, _In_ const OrtKernelInfo* info, + _In_reads_(num_tensors) const OrtValue* const* src_tensors, + _In_reads_(num_tensors) OrtValue* const* dst_tensors, + _In_opt_ OrtSyncStream* stream, + _In_ size_t num_tensors); +ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc new file mode 100644 index 0000000000000..d60e14ac287d6 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_kernel_registration.h" + +#include +#include + +#include "core/framework/error_code_helper.h" +#include "core/framework/kernel_registry.h" +#include "core/session/plugin_ep/ep_api.h" + +namespace onnxruntime { + +/// +/// OpKernel that wraps a OrtKernelImpl provided by a plugin EP. +/// +class PluginEpOpKernel final : public OpKernel { + private: + struct PrivateTag {}; + + public: + PluginEpOpKernel(const OpKernelInfo& info, PrivateTag) + : OpKernel{info} {} + + static Status Create(FuncManager& fn_manager, const OpKernelInfo& info, + OrtKernelCreateFunc kernel_create_func, void* kernel_create_func_state, + /*out*/ std::unique_ptr& op_kernel); + + ~PluginEpOpKernel() { + kernel_impl_->Release(kernel_impl_); + } + + Status Compute(OpKernelContext* ctx) const override { + return ToStatusAndRelease(kernel_impl_->Compute(kernel_impl_, reinterpret_cast(ctx))); + } + + private: + OrtKernelImpl* kernel_impl_ = nullptr; +}; + +/*static*/ +Status PluginEpOpKernel::Create(FuncManager& fn_manager, const OpKernelInfo& info, + OrtKernelCreateFunc kernel_create_func, void* kernel_create_func_state, + /*out*/ std::unique_ptr& op_kernel) { + // OpKernel's constructor *copies* the OpKernelInfo. + // Therefore, must create the OpKernel instance immediately so that we can pass the actual OpKernelInfo + // to the plugin EP's kernel creation function. + op_kernel = std::make_unique(info, PrivateTag{}); + + OrtKernelCreateContext* create_ctx = reinterpret_cast(&fn_manager); + const OrtKernelInfo* kernel_info = reinterpret_cast(&op_kernel->Info()); + + ORT_RETURN_IF_ERROR(ToStatusAndRelease( + kernel_create_func(create_ctx, kernel_create_func_state, kernel_info, &op_kernel->kernel_impl_))); + + return Status::OK(); +} + +/// +/// A functor that creates a PluginEpOpKernel instance using the creation function (+ state) provided by a plugin EP. +/// +class PluginEpKernelCreateFunctor { + public: + PluginEpKernelCreateFunctor() : kernel_create_func_(nullptr), kernel_create_func_state_(nullptr) {} + PluginEpKernelCreateFunctor(OrtKernelCreateFunc create_func, void* state) + : kernel_create_func_{create_func}, kernel_create_func_state_{state} {} + + Status operator()(FuncManager& fn_manager, const OpKernelInfo& info, std::unique_ptr& out) { + if (kernel_create_func_ == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "PluginEpKernelCreateFunctor does not wrap a valid OrtKernelCreateFunc"); + } + + std::unique_ptr plugin_ep_op_kernel; + ORT_RETURN_IF_ERROR(PluginEpOpKernel::Create(fn_manager, info, kernel_create_func_, kernel_create_func_state_, + plugin_ep_op_kernel)); + + out = std::move(plugin_ep_op_kernel); + return Status::OK(); + } + + private: + OrtKernelCreateFunc kernel_create_func_; + void* kernel_create_func_state_; +}; + +// Make a KernelCreateInfo for a plugin EP's kernel +KernelCreateInfo MakePluginEpKernelCreateInfo(const KernelDef* kernel_def, + OrtKernelCreateFunc kernel_create_func, + void* kernel_create_func_state) { + auto kernel_def_copy = std::make_unique(*kernel_def); + PluginEpKernelCreateFunctor kernel_create_functor(kernel_create_func, kernel_create_func_state); + return KernelCreateInfo(std::move(kernel_def_copy), kernel_create_functor); +} + +// Gets an OrtEp instance's kernel registry. +Status GetPluginEpKernelRegistry(OrtEp& ort_ep, /*out*/ std::shared_ptr& kernel_registry) { + kernel_registry = nullptr; + + if (ort_ep.GetKernelRegistry == nullptr) { + return Status::OK(); + } + + const OrtKernelRegistry* ep_kernel_registry = nullptr; + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ort_ep.GetKernelRegistry(&ort_ep, &ep_kernel_registry))); + + if (ep_kernel_registry != nullptr) { + kernel_registry = ep_kernel_registry->registry; + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.h b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.h new file mode 100644 index 0000000000000..9d046c8c3420e --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/session/onnxruntime_c_api.h" +#include "core/framework/data_types.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/kernel_def_builder.h" +#include "core/framework/kernel_registry.h" +#include "core/framework/op_kernel.h" + +struct OrtMLDataType : onnxruntime::DataTypeImpl {}; + +struct OrtKernelDefBuilder : onnxruntime::KernelDefBuilder {}; + +struct OrtKernelDef : onnxruntime::KernelDef {}; + +struct OrtKernelRegistry { + std::shared_ptr registry; +}; + +namespace onnxruntime { + +/// +/// Make a KernelCreateInfo for a plugin EP's kernel. A KernelCreateInfo contains the function and state +/// necessary to create a kernel. +/// +/// +/// +/// +/// +KernelCreateInfo MakePluginEpKernelCreateInfo(const KernelDef* kernel_def, + OrtKernelCreateFunc kernel_create_func, + void* kernel_create_func_state); + +/// +/// Gets the kernel registry for a plugin EP. +/// +/// The OrtEp instance. +/// Output parameter set to the EP's registry. +/// A status indicating success or an error +Status GetPluginEpKernelRegistry(OrtEp& ort_ep, /*out*/ std::shared_ptr& kernel_registry); + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 55245420db37a..e258cb8f31b0a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -21,6 +21,7 @@ #include "core/session/abi_logger.h" #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" +#include "core/session/plugin_ep/ep_kernel_registration.h" #include "core/session/ort_apis.h" #include "core/providers/partitioning_utils.h" @@ -46,18 +47,15 @@ PluginExecutionProviderFactory::PluginExecutionProviderFactory(OrtEpFactory& ep_ std::unique_ptr PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_options, const OrtLogger& session_logger) { - OrtEp* ort_ep = nullptr; - Status status = ToStatusAndRelease(ep_factory_.CreateEp(&ep_factory_, hardware_devices_.data(), ep_metadata_.data(), - hardware_devices_.size(), &session_options, &session_logger, - &ort_ep)); - + std::unique_ptr plugin_ep; + Status status = PluginExecutionProvider::Create(ep_factory_, devices_, hardware_devices_, + ep_metadata_, session_options, session_logger, plugin_ep); if (!status.IsOK()) { - ORT_THROW("Error creating execution provider: ", status.ToString()); + LOGS(*session_logger.ToInternal(), ERROR) << "Error creating execution provider: " << status.ToString(); + return nullptr; } - return std::make_unique(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)), - session_options, ep_factory_, devices_, - *session_logger.ToInternal()); + return plugin_ep; } /// @@ -129,14 +127,42 @@ static const Node* FindFirstNodeAssignedToOtherEP(const std::string& ep_type, return node_iter != ep_nodes.end() ? &(*node_iter)->GetInternalNode() : nullptr; } +/*static*/ +Status PluginExecutionProvider::Create(OrtEpFactory& ep_factory, + gsl::span ep_devices, + gsl::span hw_devices, + gsl::span ep_metadata, + const OrtSessionOptions& session_options, + const OrtLogger& logger, + /*out*/ std::unique_ptr& plugin_ep) { + plugin_ep = nullptr; + OrtEp* ort_ep = nullptr; + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory.CreateEp(&ep_factory, hw_devices.data(), ep_metadata.data(), + hw_devices.size(), &session_options, &logger, + &ort_ep))); + ORT_RETURN_IF(ort_ep == nullptr, "OrtEpFactory::CreateEp() for '", ep_factory.GetName(&ep_factory), + "' returned a NULL OrtEp instance"); + + std::shared_ptr kernel_registry; + ORT_RETURN_IF_ERROR(GetPluginEpKernelRegistry(*ort_ep, kernel_registry)); + + plugin_ep = std::make_unique(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory)), + session_options, ep_factory, ep_devices, + kernel_registry, + *logger.ToInternal()); + return Status::OK(); +} + PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, gsl::span ep_devices, + std::shared_ptr kernel_registry, const logging::Logger& logger) : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices), logger), ort_ep_(std::move(ep)), ep_factory_(ep_factory), - ep_devices_(ep_devices.begin(), ep_devices.end()) { + ep_devices_(ep_devices.begin(), ep_devices.end()), + kernel_registry_(std::move(kernel_registry)) { generate_ep_ctx_model_ = session_options.value.GetEpContextGenerationOptions().enable; for (const auto* ep_device : ep_devices_) { @@ -161,6 +187,10 @@ PluginExecutionProvider::~PluginExecutionProvider() { } } +std::shared_ptr PluginExecutionProvider::GetKernelRegistry() const { + return kernel_registry_; +} + std::vector> PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, @@ -168,7 +198,6 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie IResourceAccountant* resource_accountant) const { ORT_UNUSED_PARAMETER(graph_optimizer_registry); // TODO: Add support ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs - ORT_UNUSED_PARAMETER(kernel_lookup); // TODO: Add support? Not used by prioritized EPs, so probably not needed? const logging::Logger& logger = GetLogger() != nullptr ? *GetLogger() : logging::LoggingManager::DefaultLogger(); @@ -178,7 +207,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } - OrtEpGraphSupportInfo api_graph_support_info(*ep_graph); + OrtEpGraphSupportInfo api_graph_support_info(*ep_graph, kernel_lookup); Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info)); if (!status.IsOK()) { diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 622bbb3f97b24..88b3965b1acad 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -64,10 +64,22 @@ class PluginExecutionProvider : public IExecutionProvider { using Base = IExecutionProvider; public: + static Status Create(OrtEpFactory& ep_factory, + gsl::span ep_devices, + gsl::span hw_devices, + gsl::span ep_metadata, + const OrtSessionOptions& session_options, + const OrtLogger& logger, + /*out*/ std::unique_ptr& plugin_ep); + explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, - gsl::span ep_devices, const logging::Logger& logger); + gsl::span ep_devices, + std::shared_ptr kernel_registry, + const logging::Logger& logger); ~PluginExecutionProvider(); + std::shared_ptr GetKernelRegistry() const override; + std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, @@ -136,5 +148,7 @@ class PluginExecutionProvider : public IExecutionProvider { // calls IExecutionProvider::GetEpContextNodes(). std::vector> ep_context_nodes_; std::vector> ep_context_node_args_; + + std::shared_ptr kernel_registry_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 6bcbda0f13b92..50039766e8932 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -7,6 +7,7 @@ #include #include +#include #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" @@ -357,12 +358,10 @@ Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, Or info.hardware_devices.size(), &options, &logger, &ep))); } else { - OrtEp* api_ep = nullptr; - ORT_RETURN_IF_ERROR(ToStatusAndRelease( - info.ep_factory->CreateEp(info.ep_factory, info.hardware_devices.data(), info.ep_metadata.data(), - info.hardware_devices.size(), &options, &logger, &api_ep))); - ep = std::make_unique(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options, - *info.ep_factory, info.devices, *logger.ToInternal()); + std::unique_ptr plugin_ep; + ORT_RETURN_IF_ERROR(PluginExecutionProvider::Create(*info.ep_factory, info.devices, info.hardware_devices, + info.ep_metadata, options, logger, plugin_ep)); + ep = std::move(plugin_ep); } return Status::OK(); diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 4a50bab5e8cbc..a13f61dc33c2f 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -5,6 +5,7 @@ #include #include +#include #include "core/framework/error_code_helper.h" #include "core/framework/execution_provider.h" @@ -20,6 +21,8 @@ #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #if !defined(ORT_MINIMAL_BUILD) +#include "core/framework/data_transfer.h" +#include "core/framework/plugin_ep_stream.h" #include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "core/session/plugin_ep/ep_library_plugin.h" @@ -497,5 +500,62 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic return Status::OK(); } + +Status CopyTensors(const DataTransferManager& data_transfer_manager, + gsl::span src_tensors, + gsl::span dst_tensors, + OrtSyncStream* stream) { + if (src_tensors.size() != dst_tensors.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expected the same number of source and destination tensors. ", + "Have ", src_tensors.size(), " source tensors and ", dst_tensors.size(), + " destination tensors."); + } + + const OrtMemoryInfo* src_memory_info = nullptr; + const OrtMemoryInfo* dst_memory_info = nullptr; + + const auto validate_and_get_mem_info = + [](gsl::span values, const OrtMemoryInfo*& mem_info) -> Status { + for (size_t i = 0; i < values.size(); ++i) { + const OrtValue* value = values[i]; + if (value == nullptr || !value->IsTensor() || !value->IsAllocated()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "OrtValue must contain Tensor with data."); + } + + if (i == 0) { + mem_info = &value->Get().Location(); + } else if (*mem_info != value->Get().Location()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "All OrtValue instances must have the same OrtMemoryInfo"); + } + } + + return Status::OK(); + }; + + ORT_RETURN_IF_ERROR(validate_and_get_mem_info(src_tensors, src_memory_info)); + ORT_RETURN_IF_ERROR(validate_and_get_mem_info(dst_tensors, dst_memory_info)); + + const auto* data_transfer = data_transfer_manager.GetDataTransfer(src_memory_info->device, dst_memory_info->device); + + if (data_transfer == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Data transfer implementation between source and destination device was not found."); + } + + std::vector pairs; + pairs.reserve(src_tensors.size()); + for (size_t i = 0; i < src_tensors.size(); ++i) { + pairs.push_back({ + src_tensors[i]->Get(), + *dst_tensors[i]->GetMutable(), + stream, + }); + } + + ORT_RETURN_IF_ERROR(data_transfer->CopyTensors(pairs)); + return Status::OK(); +} #endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index 2ccd4d464a261..f3cd3abeef6c2 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -22,6 +22,7 @@ class ModelCompilationOptions; #if !defined(ORT_MINIMAL_BUILD) namespace onnxruntime { +class DataTransferManager; class Environment; class EpLibrary; class EpFactoryInternal; @@ -71,5 +72,9 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic gsl::span ep_options_vals, SessionOptions& session_options); +Status CopyTensors(const DataTransferManager& data_transfer_manager, + gsl::span src_tensors, + gsl::span dst_tensors, + OrtSyncStream* stream = nullptr); } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index 5d8245618dcd6..a798c7a5158f6 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -16,12 +16,12 @@ #include "ep_stream_support.h" /// -/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. +/// Example implementation of ONNX Mul for compiling EP. Does not handle many things like broadcasting. /// -struct MulKernel { - MulKernel(const OrtApi& ort_api, const OrtLogger& logger, - const std::unordered_map& float_initializers, - std::string input0_name, std::string input1_name) +struct CompiledMul { + CompiledMul(const OrtApi& ort_api, const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, std::string input1_name) : ort_api(ort_api), logger(logger), float_initializers(float_initializers), @@ -52,7 +52,7 @@ struct MulKernel { OrtStatus* Compute(OrtKernelContext* kernel_ctx) { RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, - "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); + "CompiledMul::Compute", ORT_FILE, __LINE__, __FUNCTION__)); Ort::KernelContext kernel_context(kernel_ctx); try { gsl::span input0; @@ -100,7 +100,7 @@ struct MulKernel { size_t num_outputs = kernel_context.GetOutputCount(); if (num_outputs != 1) { - throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); + throw Ort::Exception("Expected 1 output for CompiledMul", ORT_INVALID_ARGUMENT); } auto output = kernel_context.GetOutput(0, shape0); @@ -159,6 +159,7 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr + GetKernelRegistry = GetKernelRegistryImpl; // optional. can be nullptr auto status = ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, @@ -347,12 +348,12 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const return status.release(); } - // Associate the name of the fused node with our MulKernel. + // Associate the name of the fused node with our CompiledMul. auto fused_node_name = fused_node.GetName(); - ep->kernels_.emplace(std::move(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, - ep->float_initializers_, - node_input_names[0], - node_input_names[1])); + ep->compiled_subgraphs_.emplace(std::move(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, + ep->float_initializers_, + node_input_names[0], + node_input_names[1])); // Update the OrtNodeComputeInfo associated with the graph. auto node_compute_info = std::make_unique(*ep); @@ -385,6 +386,19 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, } } +/*static*/ +OrtStatus* ORT_API_CALL ExampleEp::GetKernelRegistryImpl( + _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry) noexcept { + ExampleEp* ep = static_cast(this_ptr); + + *kernel_registry = nullptr; + + // Get the cached kernel registry from parent factory to avoid recreating the kernel registry for every EP instance. + RETURN_IF_ERROR(ep->factory_.GetKernelRegistryForEp(*ep, kernel_registry)); + return nullptr; +} + // Creates EPContext nodes from the given fused nodes. // This is an example implementation that can be used to generate an EPContext model. However, this example EP // cannot currently run the EPContext model. @@ -516,27 +530,27 @@ OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, ExampleEp& ep = node_compute_info->ep; std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); - auto kernel_it = ep.Kernels().find(fused_node_name); - if (kernel_it == ep.Kernels().end()) { - std::string message = "Unable to get kernel for fused node with name " + fused_node_name; + auto subgraph_it = ep.CompiledSubgraphs().find(fused_node_name); + if (subgraph_it == ep.CompiledSubgraphs().end()) { + std::string message = "Unable to get compiled subgraph for fused node with name " + fused_node_name; return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); } - MulKernel& kernel = *kernel_it->second; - *compute_state = &kernel; + CompiledMul& subgraph = *subgraph_it->second; + *compute_state = &subgraph; return nullptr; } OrtStatus* ExampleNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, OrtKernelContext* kernel_context) { (void)this_ptr; - MulKernel& kernel = *reinterpret_cast(compute_state); - return kernel.Compute(kernel_context); + CompiledMul& subgraph = *reinterpret_cast(compute_state); + return subgraph.Compute(kernel_context); } void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { (void)this_ptr; - MulKernel& kernel = *reinterpret_cast(compute_state); - (void)kernel; + CompiledMul& subgraph = *reinterpret_cast(compute_state); + (void)subgraph; // Do nothing for this example. } diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index 279925a7ec3e1..5a2bafeed2ddb 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -8,7 +8,7 @@ #include "example_plugin_ep_utils.h" class ExampleEpFactory; -struct MulKernel; +struct CompiledMul; /// /// Example EP that can compile a single Mul operator. @@ -24,8 +24,8 @@ class ExampleEp : public OrtEp, public ApiPtrs { ~ExampleEp(); - std::unordered_map>& Kernels() { - return kernels_; + std::unordered_map>& CompiledSubgraphs() { + return compiled_subgraphs_; } private: @@ -51,6 +51,10 @@ class ExampleEp : public OrtEp, public ApiPtrs { OrtNodeComputeInfo** node_compute_infos, size_t num_node_compute_infos) noexcept; + static OrtStatus* ORT_API_CALL GetKernelRegistryImpl( + _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry) noexcept; + OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); @@ -60,6 +64,6 @@ class ExampleEp : public OrtEp, public ApiPtrs { std::string name_; Config config_{}; const OrtLogger& logger_; - std::unordered_map> kernels_; + std::unordered_map> compiled_subgraphs_; std::unordered_map float_initializers_; }; diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 4da7d722a5e0b..a6a1ccdb8bae3 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -9,6 +9,7 @@ #include "ep_allocator.h" #include "ep_arena.h" #include "ep_data_transfer.h" +#include "ep_kernel_registration.h" #include "ep_stream_support.h" ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtLogger& default_logger) @@ -73,6 +74,12 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL ort_api.ReleaseMemoryInfo(mem_info); } +ExampleEpFactory::~ExampleEpFactory() { + if (kernel_registry_ != nullptr) { + Ort::GetEpApi().ReleaseKernelRegistry(kernel_registry_); + } +} + /*static*/ const char* ORT_API_CALL ExampleEpFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); @@ -312,3 +319,27 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac return nullptr; } + +OrtStatus* ExampleEpFactory::GetKernelRegistryForEp(ExampleEp& ep, const OrtKernelRegistry** out_kernel_registry) { + *out_kernel_registry = nullptr; + + if (GetNumKernels() == 0) { + return nullptr; + } + + if (kernel_registry_ == nullptr) { + void* op_kernel_state = &ep; // Optional state that is provided to kernels on creation (can be null). + // This example just passes the entire OrtEp to the kernel. + + const char* ep_name = ep.GetName(static_cast(&ep)); + + // This statement creates the kernel registry and caches it in the OrtEpFactory instance. + // We assume that all EPs created by this factory can use the same kernel registry. This may not be the + // case in a more complex OrtEpFactory that can create EP instances that are each configured for different + // hardware devices. In such a scenario, a different kernel registry may be created for each EP configuration. + RETURN_IF_ERROR(CreateKernelRegistry(ep_name, op_kernel_state, &kernel_registry_)); + } + + *out_kernel_registry = kernel_registry_; + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 088deda1fe9d2..46f6d9aef9d46 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -9,12 +9,15 @@ #include "ep_data_transfer.h" #include "example_plugin_ep_utils.h" +class ExampleEp; + /// /// Example EP factory that can create an OrtEp and return information about the supported hardware devices. /// class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { public: ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtLogger& default_logger); + ~ExampleEpFactory(); OrtDataTransferImpl* GetDataTransfer() const { return data_transfer_impl_.get(); @@ -25,6 +28,9 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { return arena_allocator_.get(); } + // Called by child OrtEp instances to retrieve the cached kernel registry for that EP. + OrtStatus* GetKernelRegistryForEp(ExampleEp& ep, /*out*/ const OrtKernelRegistry** kernel_registry); + private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -84,4 +90,10 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { std::mutex mutex_; // mutex to protect arena_allocator_ and num_arena_users_ std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory + + // Cached kernel registry used by all OrtEp instances created by this factory. Refer to OrtEp::GetKernelRegistry. + // + // Note: If this factory instead created EP instances that each supported different hardware configurations, then + // the factory could cache a different kernel registry per EP configuration. + OrtKernelRegistry* kernel_registry_ = nullptr; }; diff --git a/onnxruntime/test/autoep/library/ep_kernel_registration.cc b/onnxruntime/test/autoep/library/ep_kernel_registration.cc new file mode 100644 index 0000000000000..1849c793a6760 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_kernel_registration.cc @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "ep_kernel_registration.h" +#include "kernels/utils.h" + +// Include kernels: +#include "kernels/memcpy.h" + +// Forward declarations of kernel classes used as template args for BuildKernelCreateInfo +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 1, MemcpyFromHost); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 1, MemcpyToHost); + +// Table of BuildKernelCreateInfo functions for each operator +static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = { + BuildKernelCreateInfo, // Dummy to avoid table becoming empty. + BuildKernelCreateInfo, + BuildKernelCreateInfo, +}; + +constexpr size_t num_kernel_create_info_funcs = sizeof(build_kernel_create_info_funcs) / + sizeof(build_kernel_create_info_funcs[0]); + +size_t GetNumKernels() { + static_assert(num_kernel_create_info_funcs >= 1); + return num_kernel_create_info_funcs - 1; +} + +OrtStatus* CreateKernelRegistry(const char* ep_name, void* create_kernel_state, OrtKernelRegistry** kernel_registry) { + *kernel_registry = nullptr; + + if (GetNumKernels() == 0) { + return nullptr; + } + + const OrtEpApi& ep_api = Ort::GetEpApi(); + RETURN_IF_ERROR(ep_api.CreateKernelRegistry(kernel_registry)); + + OrtStatus* status = nullptr; + + // Add kernel creation info to registry + for (auto& build_func : build_kernel_create_info_funcs) { + KernelCreateInfo kernel_create_info = {}; + status = build_func(ep_name, create_kernel_state, &kernel_create_info); + + if (status != nullptr) { + break; + } + + if (kernel_create_info.kernel_def != nullptr) { + status = ep_api.KernelRegistry_AddKernel(*kernel_registry, + kernel_create_info.kernel_def, // copied + kernel_create_info.kernel_create_func, + kernel_create_info.kernel_create_func_state); + if (status != nullptr) { + break; + } + } + } + + if (status != nullptr) { + ep_api.ReleaseKernelRegistry(*kernel_registry); + *kernel_registry = nullptr; + } + + return status; +} diff --git a/onnxruntime/test/autoep/library/ep_kernel_registration.h b/onnxruntime/test/autoep/library/ep_kernel_registration.h new file mode 100644 index 0000000000000..2a392828029be --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_kernel_registration.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "example_plugin_ep_utils.h" + +size_t GetNumKernels(); + +OrtStatus* CreateKernelRegistry(const char* ep_name, void* create_kernel_state, OrtKernelRegistry** kernel_registry); diff --git a/onnxruntime/test/autoep/library/kernels/data_types.cc b/onnxruntime/test/autoep/library/kernels/data_types.cc new file mode 100644 index 0000000000000..8662928517af6 --- /dev/null +++ b/onnxruntime/test/autoep/library/kernels/data_types.cc @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "data_types.h" + +MLDataTypes::MLDataTypes() {} + +/*static*/ +MLDataTypes& MLDataTypes::GetInstance() { + static MLDataTypes instance; + return instance; +} + +/*static*/ +OrtStatus* MLDataTypes::GetTensorType(ONNXTensorElementDataType elem_type, /*out*/ const OrtMLDataType*& tensor_type) { + MLDataTypes& instance = GetInstance(); + const OrtEpApi& ep_api = Ort::GetEpApi(); + + auto iter = instance.tensor_types_map_.find(elem_type); + if (iter == instance.tensor_types_map_.end()) { + const OrtMLDataType* type = nullptr; + + RETURN_IF_ERROR(ep_api.GetTensorMLDataType(elem_type, &type)); + instance.tensor_types_map_.emplace(elem_type, type); + + tensor_type = type; + return nullptr; + } + + tensor_type = iter->second; + return nullptr; +} + +/*static*/ +const OrtMLDataType* MLDataTypes::GetTensorType(ONNXTensorElementDataType elem_type) { + const OrtMLDataType* result = nullptr; + Ort::ThrowOnError(MLDataTypes::GetTensorType(elem_type, result)); + return result; +} + +/*static*/ +OrtStatus* MLDataTypes::GetAllFixedSizeTensorTypesIRv9(/*out*/ std::vector& result) { + MLDataTypes& instance = GetInstance(); + const OrtEpApi& ep_api = Ort::GetEpApi(); + + if (instance.fixed_tensor_v9_.empty()) { + auto add_tensor_type = [&instance, &ep_api](ONNXTensorElementDataType elem_type) -> OrtStatus* { + const OrtMLDataType* type = nullptr; + + RETURN_IF_ERROR(instance.GetTensorType(elem_type, type)); + instance.fixed_tensor_v9_.push_back(type); + return nullptr; + }; + + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4)); + RETURN_IF_ERROR(add_tensor_type(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4)); + } + + result = instance.fixed_tensor_v9_; + return nullptr; +} + +/*static*/ +std::vector MLDataTypes::GetAllFixedSizeTensorTypesIRv9() { + std::vector result; + Ort::ThrowOnError(GetInstance().GetAllFixedSizeTensorTypesIRv9(result)); + return result; +} diff --git a/onnxruntime/test/autoep/library/kernels/data_types.h b/onnxruntime/test/autoep/library/kernels/data_types.h new file mode 100644 index 0000000000000..ffb05f09dd175 --- /dev/null +++ b/onnxruntime/test/autoep/library/kernels/data_types.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "../example_plugin_ep_utils.h" + +/// +/// Singleton that returns sets of OrtMLDataType instances using the public C API. +/// Analogous to the internal utilities in include/onnxruntime/core/framework/data_types.h +/// +class MLDataTypes { + public: + static MLDataTypes& GetInstance(); + + static OrtStatus* GetTensorType(ONNXTensorElementDataType elem_type, /*out*/ const OrtMLDataType*& tensor_type); + static const OrtMLDataType* GetTensorType(ONNXTensorElementDataType elem_type); + + static OrtStatus* GetAllFixedSizeTensorTypesIRv9(/*out*/ std::vector& tensor_types); + static std::vector GetAllFixedSizeTensorTypesIRv9(); + + private: + MLDataTypes(); + + std::unordered_map tensor_types_map_; + std::vector fixed_tensor_v9_; +}; diff --git a/onnxruntime/test/autoep/library/kernels/memcpy.cc b/onnxruntime/test/autoep/library/kernels/memcpy.cc new file mode 100644 index 0000000000000..96bb32b94ca09 --- /dev/null +++ b/onnxruntime/test/autoep/library/kernels/memcpy.cc @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "memcpy.h" +#include "utils.h" + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + (Ort::KernelDefBuilder() + .SetInputMemType(0, OrtMemType::OrtMemTypeCPUInput) + // .AddTypeConstraint("T", MLDataTypes::GetAllFixedSizeTensorTypesIRv9()), + .AddTypeConstraint("T", MLDataTypes::GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + Memcpy) + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + (Ort::KernelDefBuilder() + .SetOutputMemType(0, OrtMemType::OrtMemTypeCPUOutput) + // .AddTypeConstraint("T", MLDataTypes::GetAllFixedSizeTensorTypesIRv9()), + .AddTypeConstraint("T", MLDataTypes::GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + Memcpy) + +Memcpy::Memcpy(const OrtKernelInfo* info, void* state) : info_{info}, state_{state} { + ort_version_supported = ORT_API_VERSION; + Compute = ComputeImpl; + Release = ReleaseImpl; +} + +/*static*/ +OrtStatus* Memcpy::Create(const OrtKernelInfo* info, void* state, + /*out*/ std::unique_ptr& result) { + const OrtApi& ort_api = Ort::GetApi(); + + try { + Ort::ConstKernelInfo kernel_info(info); + + // Basic validation before creating kernel. + size_t num_inputs = kernel_info.GetInputCount(); + size_t num_outputs = kernel_info.GetOutputCount(); + RETURN_IF(num_inputs != 1, ort_api, "Expected only 1 input for Memcpy kernel"); + RETURN_IF(num_outputs != 1, ort_api, "Expected only 1 output for Memcpy kernel"); + + result = std::make_unique(info, state); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL Memcpy::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { + Memcpy* memcpy = static_cast(this_ptr); + return memcpy->DoCompute(kernel_ctx); +} + +/*static*/ +void ORT_API_CALL Memcpy::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +OrtStatus* Memcpy::DoCompute(OrtKernelContext* kernel_ctx) noexcept { + const OrtEpApi& ep_api = Ort::GetEpApi(); + Ort::KernelContext kernel_context(kernel_ctx); + + try { + Ort::ConstValue input = kernel_context.GetInput(0); + std::vector shape = input.GetTensorTypeAndShapeInfo().GetShape(); + Ort::UnownedValue output = kernel_context.GetOutput(0, shape); + + std::array src_tensors = {input}; + std::array dst_tensors = {output}; + + RETURN_IF_ERROR(ep_api.KernelInfo_CopyTensors(info_, + src_tensors.data(), + dst_tensors.data(), + /*stream*/ nullptr, + src_tensors.size())); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/kernels/memcpy.h b/onnxruntime/test/autoep/library/kernels/memcpy.h new file mode 100644 index 0000000000000..e541ef42ef7fa --- /dev/null +++ b/onnxruntime/test/autoep/library/kernels/memcpy.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../example_plugin_ep_utils.h" + +struct Memcpy : public OrtKernelImpl { + static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel); + + Memcpy(const OrtKernelInfo* info, void* state); + + static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; + + OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) noexcept; + + private: + const OrtKernelInfo* info_; + void* state_; // Custom state passed from OrtEp +}; diff --git a/onnxruntime/test/autoep/library/kernels/utils.h b/onnxruntime/test/autoep/library/kernels/utils.h new file mode 100644 index 0000000000000..55a6d946573b1 --- /dev/null +++ b/onnxruntime/test/autoep/library/kernels/utils.h @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../example_plugin_ep_utils.h" +#include "data_types.h" + +/// +/// Contains information to create a kernel: kernel definition, creation function + state. +/// +struct KernelCreateInfo { + KernelCreateInfo() = default; + KernelCreateInfo(Ort::KernelDef def, OrtKernelCreateFunc func, void* state) + : kernel_def{std::move(def)}, kernel_create_func{func}, kernel_create_func_state{state} {} + + Ort::KernelDef kernel_def{nullptr}; + OrtKernelCreateFunc kernel_create_func = nullptr; + void* kernel_create_func_state = nullptr; +}; + +using BuildKernelCreateInfoFn = OrtStatus* (*)(const char*, void*, KernelCreateInfo*); + +template +OrtStatus* BuildKernelCreateInfo(const char* ep_name, void* create_func_state, /*out*/ KernelCreateInfo* result); + +template <> +inline OrtStatus* BuildKernelCreateInfo(const char* /*ep_name*/, void* /*create_func_state*/, + /*out*/ KernelCreateInfo* result) { + result->kernel_def = Ort::KernelDef{nullptr}; + result->kernel_create_func = nullptr; + result->kernel_create_func_state = nullptr; + return nullptr; +} + +static constexpr const char* kOnnxDomain = ""; + +// Naming convention for operator kernel classes +#define ONNX_OPERATOR_KERNEL_CLASS_NAME(domain, ver, name) \ + example_ep_##name##_##domain##_ver##ver + +#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, builder, kernel_class) \ + class ONNX_OPERATOR_KERNEL_CLASS_NAME(domain, ver, name); \ + template <> \ + OrtStatus* \ + BuildKernelCreateInfo(const char* ep_name, \ + void* create_kernel_state, \ + KernelCreateInfo* result) { \ + try { \ + Ort::KernelDef kernel_def = builder.SetOperatorType(#name) \ + .SetDomain(domain) \ + .SetSinceVersion(ver) \ + .SetExecutionProvider(ep_name) \ + .Build(); \ + \ + auto kernel_create_func = [](OrtKernelCreateContext* /*ctx*/, void* state, const OrtKernelInfo* info, \ + OrtKernelImpl** kernel_out) noexcept -> OrtStatus* { \ + *kernel_out = nullptr; \ + \ + std::unique_ptr kernel; \ + RETURN_IF_ERROR(kernel_class::Create(info, state, kernel)); \ + *kernel_out = kernel.release(); \ + return nullptr; \ + }; \ + \ + *result = KernelCreateInfo(std::move(kernel_def), kernel_create_func, create_kernel_state); \ + } catch (const Ort::Exception& ex) { \ + Ort::Status status(ex); \ + return status.release(); \ + } catch (const std::exception& ex) { \ + Ort::Status status(ex.what(), ORT_EP_FAIL); \ + return status.release(); \ + } \ + return nullptr; \ + } diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 78be22d082692..11e05952755fc 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -111,6 +111,9 @@ TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { } } +// Runs a model on a plugin EP that only supports one of three model nodes. +// Because the plugin EP pretends to run on GPU, this unit test will also use the plugin EP's +// kernel registry to create Memcpy ops that copy the I/O to/from the EP. TEST(OrtEpLibrary, PluginEp_AppendV2_PartiallySupportedModelInference) { RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 30595d5ce97b2..0d19dca104dde 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -8,6 +8,8 @@ #include "gtest/gtest.h" #include "core/common/logging/sinks/file_sink.h" +#include "core/framework/kernel_def_builder.h" +#include "core/framework/op_kernel.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/optimizer/graph_optimizer_registry.h" @@ -36,6 +38,14 @@ static void CheckStringInFile(const PathString& filename, const std::string& loo EXPECT_NE(content.find(look_for), std::string::npos); } +static void CheckFileIsEmpty(const PathString& filename) { + std::ifstream ifs{filename}; + std::string content(std::istreambuf_iterator{ifs}, + std::istreambuf_iterator{}); + + EXPECT_TRUE(content.empty()); +} + // Normally, a plugin EP would be implemented in a separate library. // The `test_plugin_ep` namespace contains a local implementation intended for unit testing. namespace test_plugin_ep { @@ -121,14 +131,25 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = { *static_cast(ort_session_options), g_test_ort_ep_factory, ep_devices, + /*kernel_registry*/ nullptr, logging_manager.DefaultLogger()); auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw}; return result; } +using LookUpKernelFunc = std::function; + class MockKernelLookup : public IExecutionProvider::IKernelLookup { - const KernelCreateInfo* LookUpKernel(const Node& /*node*/) const override { return nullptr; } + public: + explicit MockKernelLookup(LookUpKernelFunc lookup = nullptr) : lookup_{lookup} {} + + const KernelCreateInfo* LookUpKernel(const Node& node) const override { + return lookup_ != nullptr ? lookup_(node) : nullptr; + } + + private: + LookUpKernelFunc lookup_ = nullptr; }; } // namespace test_plugin_ep @@ -435,10 +456,23 @@ static OrtStatus* ORT_API_CALL GetCapabilityTakeSingleNode(OrtEp* this_ptr, cons return st; } - // Take only the first node using EpGraphSupportInfo_AddSingleNode(). - if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, nodes[0]); - st != nullptr) { - return st; + // Take only the first node that has a registered kernel for this EP. + for (const OrtNode* node : nodes) { + const OrtKernelDef* kernel_def = nullptr; + OrtStatus* status = this_ep->ep_api->EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def); + + if (status != nullptr) { + return status; + } + + if (kernel_def != nullptr) { + if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, node); + st != nullptr) { + return st; + } + + break; + } } return nullptr; @@ -454,7 +488,8 @@ TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { auto run_test = [&log_file](IExecutionProvider& ep, const std::unordered_set& nodes_for_other_ep, const std::unordered_set& nodes_for_this_ep, - const char* expected_log_string) { + const char* expected_log_string, + test_plugin_ep::LookUpKernelFunc lookup_kernel_func = nullptr) { std::shared_ptr model; ASSERT_NO_FATAL_FAILURE(LoadModelAndAssignNodesToEp(ORT_TSTR("testdata/add_mul_add.onnx"), "OtherEp", nodes_for_other_ep, model)); @@ -471,7 +506,7 @@ TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { GraphViewer graph_viewer(model->MainGraph()); auto compute_capabilities = ep.GetCapability(graph_viewer, - test_plugin_ep::MockKernelLookup{}, + test_plugin_ep::MockKernelLookup(lookup_kernel_func), GraphOptimizerRegistry(nullptr, nullptr, file_logger.get()), nullptr); @@ -489,7 +524,12 @@ TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { } ASSERT_TRUE(std::filesystem::exists(log_file)); - EXPECT_NO_FATAL_FAILURE(CheckStringInFile(log_file, expected_log_string)); + + if (expected_log_string != nullptr) { + EXPECT_NO_FATAL_FAILURE(CheckStringInFile(log_file, expected_log_string)); + } else { + EXPECT_NO_FATAL_FAILURE(CheckFileIsEmpty(log_file)); + } }; constexpr std::array node_names = {"add_0", "mul_0", "add_1"}; @@ -536,6 +576,19 @@ TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + // Build dummy kernel definition for an Add node. Retrieved by OrtEp using EpGraphSupportInfo_LookUpKernel(). + KernelDefBuilder builder; + builder.SetName("Add").SinceVersion(1).Provider("TestOrtEp"); + auto add_kernel_create_info = std::make_unique(builder.Build(), nullptr); + + auto mock_kernel_lookup_fn = [&add_kernel_create_info](const Node& node) -> const KernelCreateInfo* { + // Only return a result for an Add node. + if (add_kernel_create_info->kernel_def->OpName() == node.OpType()) { + return add_kernel_create_info.get(); + } + return nullptr; + }; + // Load a model and assign the first Add node to another EP named 'OtherEp'. // The plugin EP will try to take only the first Add node with a single call to EpGraphSupportInfo_AddSingleNode. // IExecutionProvider::GetCapability() will return an empty result and log a warning. @@ -543,9 +596,60 @@ TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { nodes_for_other_ep = std::unordered_set{"add_0"}; nodes_for_this_ep = std::unordered_set{}; run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, - "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'", + mock_kernel_lookup_fn); + + // Load a model and assign the last Add node to another EP named 'OtherEp'. + // The plugin EP will try to take only the first Add node with a single call to EpGraphSupportInfo_AddSingleNode. + // IExecutionProvider::GetCapability() will return a single capability and will not log warnings. + ort_ep->GetCapability = GetCapabilityTakeSingleNode; + nodes_for_other_ep = std::unordered_set{"add_1"}; + nodes_for_this_ep = std::unordered_set{"add_0"}; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + /*expected_log_string*/ nullptr, mock_kernel_lookup_fn); std::filesystem::remove(log_file); } +TEST(PluginExecutionProviderTest, KernelDefCxxApis) { + auto check_kernel_def = [&](const KernelDef& expected, Ort::ConstKernelDef actual) -> void { + EXPECT_EQ(expected.OpName(), actual.GetOperatorType()); + EXPECT_EQ(expected.Domain(), actual.GetDomain()); + EXPECT_EQ(expected.SinceVersion(), actual.GetSinceVersion()); + EXPECT_EQ(expected.Provider(), actual.GetExecutionProvider()); + EXPECT_EQ(expected.InputMemoryType(0), actual.GetInputMemType(0)); + EXPECT_EQ(expected.InputMemoryType(1), actual.GetInputMemType(1)); + EXPECT_EQ(expected.OutputMemoryType(1), actual.GetOutputMemType(1)); + }; + + // Check that C++ APIs for Ort::KernelDef return the expected values. + { + KernelDefBuilder builder; + std::unique_ptr expected_def = builder.SetName("Mul") + .SetDomain("TestDomain") + .SinceVersion(3, 13) + .Provider("TestOrtEp") + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .OutputMemoryType(OrtMemTypeCPUOutput, 1) + .Build(); + + Ort::ConstKernelDef actual_def(reinterpret_cast(expected_def.get())); + EXPECT_NO_FATAL_FAILURE(check_kernel_def(*expected_def, actual_def)); + } + + // SinceVersion with no explicit end (defaults to -1) + { + KernelDefBuilder builder; + std::unique_ptr expected_def = builder.SetName("Mul") + .SetDomain("TestDomain") + .Provider("TestOrtEp") + .SinceVersion(3) // end should default to -1 + .Build(); + + Ort::ConstKernelDef actual_def(reinterpret_cast(expected_def.get())); + EXPECT_NO_FATAL_FAILURE(check_kernel_def(*expected_def, actual_def)); + } +} + } // namespace onnxruntime::test