-
Notifications
You must be signed in to change notification settings - Fork 3.5k
[EP ABI] Initial support for kernel-based EPs #26206
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 24 commits
8c03c61
40d7866
ec46ea5
e8532a9
246681c
6b61b91
4a112fc
0f870d0
9be6923
5df3fb5
5aade60
d159b38
8187233
bf92f04
da14d65
73ae307
fb4a6a6
b8867d6
de8be32
dc78ec3
8babb63
90bf598
9f95589
90e4fc1
5f52cdc
9bfc281
60ea06c
33ffd8d
31cdc82
93b99e6
995e25b
0f7145f
f528f6f
9fbf230
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,29 @@ ORT_RUNTIME_CLASS(DataTransferImpl); | |
ORT_RUNTIME_CLASS(SyncNotificationImpl); | ||
ORT_RUNTIME_CLASS(SyncStreamImpl); | ||
|
||
// Opaque types for kernel-based EPs | ||
ORT_RUNTIME_CLASS(KernelRegistry); | ||
ORT_RUNTIME_CLASS(KernelCreateContext); // stand-in for FuncManager. may not be needed. | ||
ORT_RUNTIME_CLASS(KernelDefBuilder); | ||
ORT_RUNTIME_CLASS(KernelDef); | ||
ORT_RUNTIME_CLASS(MLDataType); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType | ||
|
||
struct OrtKernelImpl; | ||
typedef struct OrtKernelImpl OrtKernelImpl; | ||
|
||
// struct that an EP implements for OpKernel computation. | ||
// NOTE(review): ORT_API2_STATUS and ORT_API_T are macros defined outside this view; presumably each | ||
// declares a function-pointer member that the EP fills in - confirm against the C API header. | ||
struct OrtKernelImpl { | ||
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION | ||
|
||
// Compute: takes this kernel instance and an OrtKernelContext; presumably returns an OrtStatus* - confirm. | ||
ORT_API2_STATUS(Compute, _In_ OrtKernelImpl* this_ptr, _In_ OrtKernelContext* context); | ||
// Release: takes this kernel instance and returns void; presumably destroys the instance - confirm ownership rules. | ||
ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr); | ||
}; | ||
|
||
// Function signature for creating an OrtKernelImpl; a function of this type is passed to | ||
// KernelRegistry_AddKernel as kernel_create_func. Returns an OrtStatus*. | ||
// ep_state is presumably the same opaque pointer given to KernelRegistry_AddKernel - TODO confirm. | ||
// kernel_out is annotated _Outptr_result_maybenull_, i.e. it may be set to null on return. | ||
typedef OrtStatus*(ORT_API_CALL* OrtKernelCreateFunc)(_In_ OrtKernelCreateContext* ctx, | ||
_In_ void* ep_state, | ||
_In_ const OrtKernelInfo* info, | ||
_Outptr_result_maybenull_ OrtKernelImpl** kernel_out); | ||
|
||
// struct that an EP implements for IDataTransfer to copy between devices it uses and CPU | ||
struct OrtDataTransferImpl { | ||
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION | ||
|
@@ -465,6 +488,42 @@ struct OrtEpApi { | |
*/ | ||
ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream, | ||
_In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream); | ||
|
||
ORT_API2_STATUS(CreateKernelRegistry, _Outptr_ OrtKernelRegistry** kernel_registry); | ||
ORT_CLASS_RELEASE(KernelRegistry); | ||
ORT_API2_STATUS(KernelRegistry_AddKernel, _In_ OrtKernelRegistry* kernel_registry, | ||
_In_ const OrtKernelDef* kernel_def, _In_ OrtKernelCreateFunc kernel_create_func, | ||
_In_ void* ep_state); | ||
|
||
ORT_API2_STATUS(CreateKernelDefBuilder, _Outptr_ OrtKernelDefBuilder** kernel_def_builder_out); | ||
ORT_CLASS_RELEASE(KernelDefBuilder); | ||
ORT_API2_STATUS(KernelDefBuilder_SetOperatorType, _In_ OrtKernelDefBuilder* kernel_def_builder, | ||
_In_ const char* op_type); | ||
ORT_API2_STATUS(KernelDefBuilder_SetDomain, _In_ OrtKernelDefBuilder* kernel_def_builder, _In_ const char* domain); | ||
ORT_API2_STATUS(KernelDefBuilder_SetSinceVersion, _In_ OrtKernelDefBuilder* kernel_def_builder, | ||
_In_ int since_version_start, _In_ int since_version_end); | ||
ORT_API2_STATUS(KernelDefBuilder_SetExecutionProvider, _In_ OrtKernelDefBuilder* kernel_def_builder, | ||
_In_ const char* ep_name); | ||
ORT_API2_STATUS(KernelDefBuilder_SetInputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, | ||
_In_ size_t input_index, _In_ OrtMemType mem_type); | ||
ORT_API2_STATUS(KernelDefBuilder_SetOutputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, | ||
_In_ size_t output_index, _In_ OrtMemType mem_type); | ||
ORT_API2_STATUS(KernelDefBuilder_AddTypeConstraint, _In_ OrtKernelDefBuilder* kernel_def_builder, | ||
_In_ const char* arg_name, _In_reads_(num_types) const OrtMLDataType* const* types, | ||
_In_ size_t num_types); | ||
ORT_API2_STATUS(KernelDefBuilder_Build, _In_ OrtKernelDefBuilder* kernel_def_builder, | ||
_Outptr_ OrtKernelDef** kernel_def_out); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This PR does not yet add all KernelDefBuilder functions: it is missing aliasing and "may inplace". However, these may not be commonly used and could be added later. |
||
ORT_CLASS_RELEASE(KernelDef); | ||
|
||
ORT_API2_STATUS(GetTensorMLDataType, _In_ ONNXTensorElementDataType elem_type, | ||
_Outptr_ const OrtMLDataType** out); | ||
Comment on lines
+750
to
+751
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have currently only added an API to get tensor data types. We would need to add similar APIs for sequences, maps, etc. Also, I'm not sure whether we should keep using the term "ML data type". I kept it to remain consistent with the internal names, but perhaps we can rename it? |
||
|
||
ORT_API2_STATUS(KernelInfo_CopyTensors, _In_ const OrtKernelInfo* info, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it possible to reuse the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The existing |
||
_In_reads_(num_tensors) const OrtValue* const* src_tensors, | ||
_In_reads_(num_tensors) OrtValue* const* dst_tensors, | ||
_In_opt_ OrtSyncStream* stream, | ||
_In_ size_t num_tensors); | ||
}; | ||
|
||
/** | ||
|
@@ -726,6 +785,9 @@ struct OrtEp { | |
*/ | ||
ORT_API_T(const char*, GetCompiledModelCompatibilityInfo, _In_ OrtEp* this_ptr, | ||
_In_ const OrtGraph* graph); | ||
|
||
ORT_API2_STATUS(GetKernelRegistry, _In_ OrtEp* this_ptr, | ||
_Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry); | ||
}; | ||
|
||
/** \brief The function signature that ORT will call to create OrtEpFactory instances. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3430,53 +3430,14 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env, | |
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments provided to CopyTensors."); | ||
} | ||
|
||
const OrtMemoryInfo* src_memory_info = nullptr; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note: moved this into a shared utility function that can be used by the new API. |
||
const OrtMemoryInfo* dst_memory_info = nullptr; | ||
|
||
const auto validate_and_get_mem_info = | ||
[](const OrtValue* const* values, size_t num_values, const OrtMemoryInfo*& mem_info) -> OrtStatus* { | ||
for (size_t i = 0; i < num_values; ++i) { | ||
const OrtValue* value = values[i]; | ||
if (value == nullptr || !value->IsTensor() || !value->IsAllocated()) { | ||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue must contain Tensor with data."); | ||
} | ||
|
||
if (i == 0) { | ||
mem_info = &value->Get<Tensor>().Location(); | ||
} else if (*mem_info != value->Get<Tensor>().Location()) { | ||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "All OrtValue instances must have the same OrtMemoryInfo"); | ||
} | ||
} | ||
|
||
return nullptr; | ||
}; | ||
|
||
ORT_API_RETURN_IF_ERROR(validate_and_get_mem_info(src_tensors, num_tensors, src_memory_info)); | ||
ORT_API_RETURN_IF_ERROR(validate_and_get_mem_info(const_cast<const OrtValue**>(dst_tensors), num_tensors, | ||
dst_memory_info)); | ||
|
||
auto& data_transfer_mgr = env->GetEnvironment().GetDataTransferManager(); | ||
const auto* data_transfer = data_transfer_mgr.GetDataTransfer(src_memory_info->device, dst_memory_info->device); | ||
|
||
if (data_transfer == nullptr) { | ||
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, | ||
"Data transfer implementation between source and destination device was not found."); | ||
} | ||
|
||
std::vector<IDataTransfer::SrcDstPair> pairs; | ||
pairs.reserve(num_tensors); | ||
for (size_t i = 0; i < num_tensors; ++i) { | ||
pairs.push_back({ | ||
src_tensors[i]->Get<Tensor>(), | ||
*dst_tensors[i]->GetMutable<Tensor>(), | ||
stream, | ||
}); | ||
} | ||
|
||
ORT_API_RETURN_IF_STATUS_NOT_OK(data_transfer->CopyTensors(pairs)); | ||
ORT_API_RETURN_IF_STATUS_NOT_OK(CopyTensors(data_transfer_mgr, | ||
gsl::span<const OrtValue* const>(src_tensors, num_tensors), | ||
gsl::span<OrtValue* const>(dst_tensors, num_tensors), | ||
stream)); | ||
|
||
return nullptr; | ||
|
||
API_IMPL_END | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: it is one data type with this overload, right?