Skip to content
Open
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
8c03c61
Sketch API funcs
adrianlizarraga Sep 23, 2025
40d7866
Merge branch 'main' into adrianl/ep-abi-kernel-based-eps
adrianlizarraga Sep 26, 2025
ec46ea5
Implement more c apis
adrianlizarraga Sep 27, 2025
e8532a9
OpKernel class for plugin EPs
adrianlizarraga Sep 27, 2025
246681c
Initialize PluginExecutionProvider's kernel registry
adrianlizarraga Sep 28, 2025
6b61b91
Move files to session
adrianlizarraga Sep 28, 2025
4a112fc
Add API to set kernel def I/O memory types
adrianlizarraga Sep 28, 2025
0f870d0
Add C API to add type constraints to a kernel definition
adrianlizarraga Sep 28, 2025
9be6923
Start implementing MemcpyFromHost kernel in example EP
adrianlizarraga Sep 28, 2025
5df3fb5
Get kernel for MemcpyFromHost working for example plugin EP!
adrianlizarraga Sep 28, 2025
5aade60
Moved example plugin EP's kernel stuff to different file
adrianlizarraga Sep 28, 2025
d159b38
Add separate utility to load OrtMLDataTypes
adrianlizarraga Sep 29, 2025
8187233
Add MLDataTypes::GetTensorType()
adrianlizarraga Sep 29, 2025
bf92f04
Add C++ Ort::KernelDefBuilder to allow creation of macro
adrianlizarraga Sep 29, 2025
da14d65
Create macro for defining BuildKernelCreateInfo functions for each op
adrianlizarraga Sep 30, 2025
73ae307
Move kernels to separate directory/files
adrianlizarraga Sep 30, 2025
fb4a6a6
Use data transfer in MemcpyFromHost and MemcpyToHost
adrianlizarraga Sep 30, 2025
b8867d6
Release OrtKernelCreateInfo instances if an error occurs
adrianlizarraga Sep 30, 2025
de8be32
Move typedef and add forward-declaration of OrtKernelImpl for gcc
adrianlizarraga Sep 30, 2025
dc78ec3
Merge branch 'main' into adrianl/ep-abi-kernel-based-eps
adrianlizarraga Sep 30, 2025
8babb63
Apply suggestions from code review
adrianlizarraga Oct 1, 2025
90bf598
Simplify with OrtKernelRegistry
adrianlizarraga Oct 2, 2025
9f95589
Pass custom state to kernel creation in plugin EP
adrianlizarraga Oct 2, 2025
90e4fc1
Clean up
adrianlizarraga Oct 2, 2025
5f52cdc
ExampleEp: cache kernel registry in factory so it can be reused by al…
adrianlizarraga Oct 2, 2025
9bfc281
Add C API to lookup a kernel from within OrtEp::GetCapability
adrianlizarraga Oct 2, 2025
60ea06c
Disambiguate a compiled subgraph (of one node) from a registered kern…
adrianlizarraga Oct 2, 2025
33ffd8d
Add unit test for EpGraphSupportInfo_LookUpKernel()
adrianlizarraga Oct 2, 2025
31cdc82
Add missing include needed for linux ci
adrianlizarraga Oct 2, 2025
93b99e6
Add KernelDef to C++ api and add basic getters
adrianlizarraga Oct 2, 2025
995e25b
Add documentation comments
adrianlizarraga Oct 3, 2025
0f7145f
Remove incorrect comment
adrianlizarraga Oct 3, 2025
f528f6f
Add missing API_IMPL_BEGIN/END macro calls
adrianlizarraga Oct 6, 2025
9fbf230
Merge branch 'main' into adrianl/ep-abi-kernel-based-eps
adrianlizarraga Oct 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1976,7 +1976,9 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
NOT onnxruntime_MINIMAL_BUILD)
# example_plugin_ep
file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h"
"${TEST_SRC_DIR}/autoep/library/*.cc")
"${TEST_SRC_DIR}/autoep/library/*.cc"
"${TEST_SRC_DIR}/autoep/library/kernels/*.h"
"${TEST_SRC_DIR}/autoep/library/kernels/*.cc")
onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src})
target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
target_link_libraries(example_plugin_ep PRIVATE onnxruntime)
Expand Down
61 changes: 61 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,8 @@ ORT_DEFINE_RELEASE(ValueInfo);

ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDef, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi);

// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type,
// but the struct has V2 in its name to indicate that it is the second version of the options.
Expand Down Expand Up @@ -3284,5 +3286,64 @@ struct Model : detail::ModelImpl<OrtModel> {
explicit Model(const std::vector<DomainOpsetPair>& opsets);
#endif
};

namespace detail {
template <typename T>
struct ConstKernelDefImpl : Base<T> {
using B = Base<T>;
using B::B;

///< Wraps OrtEpApi::KernelDef_GetOperatorType
const char* GetOperatorType() const;

///< Wraps OrtEpApi::KernelDef_GetDomain
const char* GetDomain() const;

///< Wraps OrtEpApi::KernelDef_GetSinceVersion
std::pair<int, int> GetSinceVersion() const;

///< Wraps OrtEpApi::KernelDef_GetExecutionProvider
const char* GetExecutionProvider() const;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If any of the information for any getters is optional, suggest returning a status

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The underlying C API function OrtEpApi::KernelDef_GetExecutionProvider returns the const char* directly (doesn't return a status).

Sorry, I don't fully understand the comment. Is the request to make this return a status instead? What is meant by "information for any getters is optional"?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a general comment not to be bound to throwing exceptions by default. In this case, the data is returned without error reporting.


///< Wraps OrtEpApi::KernelDef_GetInputMemType
OrtMemType GetInputMemType(size_t input_index) const;

///< Wraps OrtEpApi::KernelDef_GetOutputMemType
OrtMemType GetOutputMemType(size_t output_index) const;
};
} // namespace detail

using ConstKernelDef = detail::ConstKernelDefImpl<detail::Unowned<const OrtKernelDef>>;

struct KernelDef : detail::ConstKernelDefImpl<OrtKernelDef> {
using Base = detail::ConstKernelDefImpl<OrtKernelDef>;
using Base::Base;

explicit KernelDef(std::nullptr_t) {}
explicit KernelDef(OrtKernelDef* p) : detail::ConstKernelDefImpl<OrtKernelDef>{p} {}

ConstKernelDef GetConst() const { return ConstKernelDef{this->p_}; }
};

/** \brief Builder for OrtKernelDef.
*
* Used by plugin EPs to build a kernel definition.
*/
struct KernelDefBuilder : detail::Base<OrtKernelDefBuilder> {
KernelDefBuilder(); ///< Wraps OrtEpApi::CreateKernelDefBuilder
explicit KernelDefBuilder(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
explicit KernelDefBuilder(OrtKernelDefBuilder* ort_kernel_def_builder);

KernelDefBuilder& SetOperatorType(const char* op_type);
KernelDefBuilder& SetDomain(const char* domain);
KernelDefBuilder& SetSinceVersion(int since_version_start, int since_version_end = -1);
KernelDefBuilder& SetExecutionProvider(const char* ep_name);
KernelDefBuilder& SetInputMemType(size_t input_index, OrtMemType mem_type);
KernelDefBuilder& SetOutputMemType(size_t output_index, OrtMemType mem_type);
KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtMLDataType* data_types);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it is one data type with this overload, right?

Suggested change
KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtMLDataType* data_types);
KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtMLDataType* data_type);

KernelDefBuilder& AddTypeConstraint(const char* arg_name, const std::vector<const OrtMLDataType*>& data_types);

KernelDef Build();
};
} // namespace Ort
#include "onnxruntime_cxx_inline.h"
97 changes: 97 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -3553,4 +3553,101 @@
}
#endif

namespace detail {
template <typename T>
inline const char* ConstKernelDefImpl<T>::GetOperatorType() const {
return GetEpApi().KernelDef_GetOperatorType(this->p_);
}

template <typename T>
inline const char* ConstKernelDefImpl<T>::GetDomain() const {
return GetEpApi().KernelDef_GetDomain(this->p_);
}

template <typename T>
inline std::pair<int, int> ConstKernelDefImpl<T>::GetSinceVersion() const {
int start = 0;
int end = 0;

ThrowOnError(GetEpApi().KernelDef_GetSinceVersion(this->p_, &start, &end));
return std::pair<int, int>(start, end);

Check warning on line 3573 in include/onnxruntime/core/session/onnxruntime_cxx_inline.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for pair<> [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_inline.h:3573: Add #include <utility> for pair<> [build/include_what_you_use] [4]
}

template <typename T>
inline const char* ConstKernelDefImpl<T>::GetExecutionProvider() const {
return GetEpApi().KernelDef_GetExecutionProvider(this->p_);
}

template <typename T>
inline OrtMemType ConstKernelDefImpl<T>::GetInputMemType(size_t input_index) const {
OrtMemType mem_type{};
ThrowOnError(GetEpApi().KernelDef_GetInputMemType(this->p_, input_index, &mem_type));

return mem_type;
}

template <typename T>
inline OrtMemType ConstKernelDefImpl<T>::GetOutputMemType(size_t output_index) const {
OrtMemType mem_type{};
ThrowOnError(GetEpApi().KernelDef_GetOutputMemType(this->p_, output_index, &mem_type));

return mem_type;
}
} // namespace detail

inline KernelDefBuilder::KernelDefBuilder() {
ThrowOnError(GetEpApi().CreateKernelDefBuilder(&p_));
}

inline KernelDefBuilder::KernelDefBuilder(OrtKernelDefBuilder* p) : detail::Base<OrtKernelDefBuilder>{p} {
}

inline KernelDefBuilder& KernelDefBuilder::SetOperatorType(const char* op_type) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetOperatorType(p_, op_type));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::SetDomain(const char* domain) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetDomain(p_, domain));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::SetSinceVersion(int since_version_start, int since_version_end) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetSinceVersion(p_, since_version_start, since_version_end));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::SetExecutionProvider(const char* ep_name) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetExecutionProvider(p_, ep_name));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::SetInputMemType(size_t input_index, OrtMemType mem_type) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetInputMemType(p_, input_index, mem_type));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::SetOutputMemType(size_t output_index, OrtMemType mem_type) {
ThrowOnError(GetEpApi().KernelDefBuilder_SetOutputMemType(p_, output_index, mem_type));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_name,
const OrtMLDataType* data_types) {
ThrowOnError(GetEpApi().KernelDefBuilder_AddTypeConstraint(p_, arg_name, &data_types, 1));
return *this;
}

inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_name,
const std::vector<const OrtMLDataType*>& data_types) {
ThrowOnError(GetEpApi().KernelDefBuilder_AddTypeConstraint(p_, arg_name, data_types.data(), data_types.size()));
return *this;
}

inline KernelDef KernelDefBuilder::Build() {
OrtKernelDef* kernel_def = nullptr;
ThrowOnError(GetEpApi().KernelDefBuilder_Build(p_, &kernel_def));
return KernelDef(kernel_def);
}

} // namespace Ort
73 changes: 73 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,29 @@ ORT_RUNTIME_CLASS(DataTransferImpl);
ORT_RUNTIME_CLASS(SyncNotificationImpl);
ORT_RUNTIME_CLASS(SyncStreamImpl);

// Opaque types for kernel-based EPs
ORT_RUNTIME_CLASS(KernelRegistry);
ORT_RUNTIME_CLASS(KernelCreateContext); // stand-in for FuncManager. may not be needed.
ORT_RUNTIME_CLASS(KernelDefBuilder);
ORT_RUNTIME_CLASS(KernelDef);
ORT_RUNTIME_CLASS(MLDataType); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType

struct OrtKernelImpl;
typedef struct OrtKernelImpl OrtKernelImpl;

// struct that an EP implements for OpKernel computation.
struct OrtKernelImpl {
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION

ORT_API2_STATUS(Compute, _In_ OrtKernelImpl* this_ptr, _In_ OrtKernelContext* context);
ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr);
};

typedef OrtStatus*(ORT_API_CALL* OrtKernelCreateFunc)(_In_ OrtKernelCreateContext* ctx,
_In_ void* ep_state,
_In_ const OrtKernelInfo* info,
_Outptr_result_maybenull_ OrtKernelImpl** kernel_out);

// struct that an EP implements for IDataTransfer to copy between devices it uses and CPU
struct OrtDataTransferImpl {
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION
Expand Down Expand Up @@ -465,6 +488,53 @@ struct OrtEpApi {
*/
ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream,
_In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream);

ORT_API2_STATUS(CreateKernelRegistry, _Outptr_ OrtKernelRegistry** kernel_registry);
ORT_CLASS_RELEASE(KernelRegistry);
ORT_API2_STATUS(KernelRegistry_AddKernel, _In_ OrtKernelRegistry* kernel_registry,
_In_ const OrtKernelDef* kernel_def, _In_ OrtKernelCreateFunc kernel_create_func,
_In_ void* ep_state);

ORT_API2_STATUS(CreateKernelDefBuilder, _Outptr_ OrtKernelDefBuilder** kernel_def_builder_out);
ORT_CLASS_RELEASE(KernelDefBuilder);
ORT_API2_STATUS(KernelDefBuilder_SetOperatorType, _In_ OrtKernelDefBuilder* kernel_def_builder,
_In_ const char* op_type);
ORT_API2_STATUS(KernelDefBuilder_SetDomain, _In_ OrtKernelDefBuilder* kernel_def_builder, _In_ const char* domain);
ORT_API2_STATUS(KernelDefBuilder_SetSinceVersion, _In_ OrtKernelDefBuilder* kernel_def_builder,
_In_ int since_version_start, _In_ int since_version_end);
ORT_API2_STATUS(KernelDefBuilder_SetExecutionProvider, _In_ OrtKernelDefBuilder* kernel_def_builder,
_In_ const char* ep_name);
ORT_API2_STATUS(KernelDefBuilder_SetInputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder,
_In_ size_t input_index, _In_ OrtMemType mem_type);
ORT_API2_STATUS(KernelDefBuilder_SetOutputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder,
_In_ size_t output_index, _In_ OrtMemType mem_type);
ORT_API2_STATUS(KernelDefBuilder_AddTypeConstraint, _In_ OrtKernelDefBuilder* kernel_def_builder,
_In_ const char* arg_name, _In_reads_(num_types) const OrtMLDataType* const* types,
_In_ size_t num_types);
ORT_API2_STATUS(KernelDefBuilder_Build, _In_ OrtKernelDefBuilder* kernel_def_builder,
_Outptr_ OrtKernelDef** kernel_def_out);

Copy link
Contributor Author

@adrianlizarraga adrianlizarraga Oct 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR does not yet add all KernelDefBuilder functions. It's missing aliasing, "may inplace". However, these things may not be used commonly and could be added later.

ORT_CLASS_RELEASE(KernelDef);
ORT_API_T(const char*, KernelDef_GetOperatorType, _In_ const OrtKernelDef* kernel_def);
ORT_API_T(const char*, KernelDef_GetDomain, _In_ const OrtKernelDef* kernel_def);
ORT_API2_STATUS(KernelDef_GetSinceVersion, _In_ const OrtKernelDef* kernel_def,
_Out_ int* start_version, _Out_ int* end_version);
ORT_API_T(const char*, KernelDef_GetExecutionProvider, _In_ const OrtKernelDef* kernel_def);
ORT_API2_STATUS(KernelDef_GetInputMemType, _In_ const OrtKernelDef* kernel_def,
_In_ size_t input_index, _Out_ OrtMemType* mem_type);
ORT_API2_STATUS(KernelDef_GetOutputMemType, _In_ const OrtKernelDef* kernel_def,
_In_ size_t output_index, _Out_ OrtMemType* mem_type);

Copy link
Contributor Author

@adrianlizarraga adrianlizarraga Oct 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also have not added all getters for KernelDef because they are not really used by EPs. An EP retrieves a kernel def during GetCapability to check if a kernel for a node has been registered. Notably, there is only one EP (ACL EP) that actually gets a property from a KernelDef returned by a lookup, and that property is the operator type, which it could instead get from the node.

ORT_API2_STATUS(GetTensorMLDataType, _In_ ONNXTensorElementDataType elem_type,
_Outptr_ const OrtMLDataType** out);
Comment on lines +750 to +751
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I currently only added an API to get tensor data types. We would need to add similar APIs for sequences, maps, etc.

Also, I'm not too sure if we should keep using the term "ML data type". I kept it to remain consistent with the internal names, but perhaps we can rename?


ORT_API2_STATUS(KernelInfo_CopyTensors, _In_ const OrtKernelInfo* info,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to reuse the CopyTensors API or do we need a new one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing CopyTensors API takes an OrtEnv as input, which is not available to EPs (if I'm not mistaken)

_In_reads_(num_tensors) const OrtValue* const* src_tensors,
_In_reads_(num_tensors) OrtValue* const* dst_tensors,
_In_opt_ OrtSyncStream* stream,
_In_ size_t num_tensors);
ORT_API2_STATUS(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def);
};

/**
Expand Down Expand Up @@ -726,6 +796,9 @@ struct OrtEp {
*/
ORT_API_T(const char*, GetCompiledModelCompatibilityInfo, _In_ OrtEp* this_ptr,
_In_ const OrtGraph* graph);

ORT_API2_STATUS(GetKernelRegistry, _In_ OrtEp* this_ptr,
_Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry);
};

/** \brief The function signature that ORT will call to create OrtEpFactory instances.
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/session/abi_ep_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#include "core/graph/ep_api_types.h"
#include "core/session/abi_devices.h"

OrtEpGraphSupportInfo::OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph,
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup)
: ort_graph(graph), kernel_lookup{kernel_lookup} {}

onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span<const OrtNode* const> nodes,
const OrtNodeFusionOptions* optional_fusion_options) {
std::vector<const onnxruntime::EpNode*> ep_nodes;
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/session/abi_ep_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "core/common/inlined_containers_fwd.h"
#include "core/common/status.h"
#include "core/framework/execution_provider.h"
#include "core/session/onnxruntime_c_api.h"

namespace onnxruntime {
Expand Down Expand Up @@ -39,12 +40,14 @@ struct OrtEpGraphSupportInfo {
OrtNodeFusionOptions fusion_options = {};
};

explicit OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph) : ort_graph(graph) {}
OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph,
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup);

onnxruntime::Status AddNodesToFuse(gsl::span<const OrtNode* const> nodes,
const OrtNodeFusionOptions* node_fusion_options = nullptr);
onnxruntime::Status AddSingleNode(const OrtNode* node);

const onnxruntime::EpGraph& ort_graph;
std::vector<NodeGrouping> node_groupings;
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup;
};
47 changes: 4 additions & 43 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3430,53 +3430,14 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env,
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments provided to CopyTensors.");
}

const OrtMemoryInfo* src_memory_info = nullptr;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: moved this into a shared utility function that can be used by the new API KernelInfo_CopyTensors

const OrtMemoryInfo* dst_memory_info = nullptr;

const auto validate_and_get_mem_info =
[](const OrtValue* const* values, size_t num_values, const OrtMemoryInfo*& mem_info) -> OrtStatus* {
for (size_t i = 0; i < num_values; ++i) {
const OrtValue* value = values[i];
if (value == nullptr || !value->IsTensor() || !value->IsAllocated()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue must contain Tensor with data.");
}

if (i == 0) {
mem_info = &value->Get<Tensor>().Location();
} else if (*mem_info != value->Get<Tensor>().Location()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "All OrtValue instances must have the same OrtMemoryInfo");
}
}

return nullptr;
};

ORT_API_RETURN_IF_ERROR(validate_and_get_mem_info(src_tensors, num_tensors, src_memory_info));
ORT_API_RETURN_IF_ERROR(validate_and_get_mem_info(const_cast<const OrtValue**>(dst_tensors), num_tensors,
dst_memory_info));

auto& data_transfer_mgr = env->GetEnvironment().GetDataTransferManager();
const auto* data_transfer = data_transfer_mgr.GetDataTransfer(src_memory_info->device, dst_memory_info->device);

if (data_transfer == nullptr) {
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED,
"Data transfer implementation between source and destination device was not found.");
}

std::vector<IDataTransfer::SrcDstPair> pairs;
pairs.reserve(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
pairs.push_back({
src_tensors[i]->Get<Tensor>(),
*dst_tensors[i]->GetMutable<Tensor>(),
stream,
});
}

ORT_API_RETURN_IF_STATUS_NOT_OK(data_transfer->CopyTensors(pairs));
ORT_API_RETURN_IF_STATUS_NOT_OK(CopyTensors(data_transfer_mgr,
gsl::span<const OrtValue* const>(src_tensors, num_tensors),
gsl::span<OrtValue* const>(dst_tensors, num_tensors),
stream));

return nullptr;

API_IMPL_END
}

Expand Down
Loading
Loading