Skip to content

Commit 4fd9ebb

Browse files
authored
[EP API] header-only adapter for EP API (#26919)
### Description This PR adds a few headers for supporting building WebGPU EP and CUDA EP as plugin EPs. See summary of #26907
1 parent 0411d41 commit 4fd9ebb

File tree

16 files changed

+1457
-0
lines changed

16 files changed

+1457
-0
lines changed

include/onnxruntime/ep/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
## EP adapter
2+
3+
This folder contains a set of C++ header files. They are used specifically for allowing ONNX Runtime internal kernel-based EPs to use the plugin-style EP API while keep minimal changes to existing code.
4+
5+
### Folder Structure
6+
7+
There are 2 types of header files:
8+
9+
- General header files for plugin EP. This may include utilities, macros and shared routines that depending on ONNX Runtime public API only. There are multiple places for header files of this category (which we are going to unify them to one place. There is an ongoing discussion about unifying shared headers for plugin EPs):
10+
- `include/onnxruntime/ep/` (#26919)
11+
- `onnxruntime/test/autoep/library/plugin_ep_utils.h`
12+
- `include/onnxruntime/core/providers/utils/` (#25753)
13+
14+
- Header files specifically used for supporting WebGPU EP and CUDA EP to use EP APIs. These header files do not only depend on ONNX Runtime public API, but also depend on ONNX Runtime internal headers. They define adapter classes that replace their compatible, internal ONNX Runtime equivalents.
15+
- `include/onnxruntime/ep/adapter/`
16+
17+
### Usage
18+
19+
Make sure to include "ep/adapters.h" to include all adapter implementation code. This file brings the adapter classes into the EP's namespace, so it should be included before other EP code that relies on the adapter classes. Using "ep/adapters.h" as a pre-compiled header is the recommended way to include it first.
20+
21+
`ep/adapters.h` has conflicts with shared provider. Shared provider should be disabled when using these adapters.
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/adapters.h instead."
8+
#endif
9+
10+
#include "core/framework/allocator.h"
11+
12+
namespace onnxruntime {
13+
namespace ep {
14+
namespace adapter {
15+
16+
/// <summary>
17+
/// A bridge class between the EP API OrtAllocator and an IAllocator implementation.
18+
/// </summary>
19+
class Allocator : public OrtAllocator {
20+
public:
21+
explicit Allocator(const OrtMemoryInfo* memory_info, AllocatorPtr impl)
22+
: OrtAllocator{}, memory_info_(memory_info), impl_(impl) {
23+
version = ORT_API_VERSION;
24+
Alloc = AllocImpl;
25+
Free = FreeImpl;
26+
Info = InfoImpl;
27+
}
28+
29+
private:
30+
static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept {
31+
auto* allocator = static_cast<Allocator*>(this_ptr);
32+
return allocator->impl_->Alloc(size);
33+
}
34+
35+
static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept {
36+
auto* allocator = static_cast<Allocator*>(this_ptr);
37+
allocator->impl_->Free(p);
38+
}
39+
40+
static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept {
41+
auto* allocator = static_cast<const Allocator*>(this_ptr);
42+
return allocator->memory_info_;
43+
}
44+
45+
const OrtMemoryInfo* memory_info_;
46+
AllocatorPtr impl_;
47+
};
48+
49+
} // namespace adapter
50+
} // namespace ep
51+
} // namespace onnxruntime
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/adapters.h instead."
8+
#endif
9+
10+
#include "core/common/status.h"
11+
#include "core/common/common.h"
12+
#include "core/framework/data_transfer.h"
13+
#include "core/framework/tensor.h"
14+
15+
namespace onnxruntime {
16+
namespace ep {
17+
namespace adapter {
18+
19+
/// <summary>
20+
/// An adapter class partially implementing the interface of `onnxruntime::DataTransferManager`.
21+
/// </summary>
22+
struct DataTransferManager {
23+
explicit DataTransferManager(std::unique_ptr<IDataTransfer> impl) : impl_{std::move(impl)} {}
24+
25+
common::Status CopyTensor(const Tensor& src, Tensor& dst) const {
26+
if (src.Shape().Size() != dst.Shape().Size()) {
27+
return ORT_MAKE_STATUS(ONNXRUNTIME,
28+
FAIL,
29+
"Tensor size mismatch: source tensor size is ",
30+
src.Shape().Size(),
31+
", destination tensor size is ",
32+
dst.Shape().Size());
33+
}
34+
35+
if (impl_->CanCopy(src.Location().device, dst.Location().device)) {
36+
return impl_->CopyTensor(src, dst);
37+
}
38+
39+
return ORT_MAKE_STATUS(ONNXRUNTIME,
40+
FAIL,
41+
"There's no data transfer registered for copying tensors from ",
42+
src.Location().device.ToString(),
43+
" to ",
44+
dst.Location().device.ToString());
45+
}
46+
47+
private:
48+
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DataTransferManager);
49+
std::unique_ptr<IDataTransfer> impl_;
50+
};
51+
52+
} // namespace adapter
53+
} // namespace ep
54+
} // namespace onnxruntime
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/adapters.h instead."
8+
#endif
9+
10+
#include "data_transfer_manager.h"
11+
12+
#include "core/framework/execution_provider.h"
13+
14+
namespace onnxruntime {
15+
namespace ep {
16+
namespace adapter {
17+
18+
/// <summary>
19+
/// Wrapper around IExecutionProvider to expose via OrtEp.
20+
/// </summary>
21+
class Ep : public OrtEp {
22+
protected:
23+
explicit Ep(std::unique_ptr<IExecutionProvider> impl, AllocatorPtr temp_space_cpu_allocator, AllocatorPtr temp_space_allocator)
24+
: OrtEp{},
25+
impl_(std::move(impl)),
26+
data_transfer_manager_{impl_->GetDataTransfer()},
27+
profiler_{impl_->GetProfiler()},
28+
temp_space_cpu_allocator_{temp_space_cpu_allocator},
29+
temp_space_allocator_{temp_space_allocator} {
30+
}
31+
32+
public:
33+
inline IExecutionProvider* EpImpl() const noexcept {
34+
return impl_.get();
35+
}
36+
inline const DataTransferManager& GetDataTransferManager() const noexcept {
37+
return data_transfer_manager_;
38+
}
39+
Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const {
40+
*output = temp_space_cpu_allocator_;
41+
return Status::OK();
42+
}
43+
Status GetTempSpaceAllocator(AllocatorPtr* output) const {
44+
*output = temp_space_allocator_;
45+
return Status::OK();
46+
}
47+
48+
private:
49+
std::unique_ptr<IExecutionProvider> impl_;
50+
DataTransferManager data_transfer_manager_;
51+
std::unique_ptr<profiling::EpProfiler> profiler_;
52+
AllocatorPtr temp_space_cpu_allocator_;
53+
AllocatorPtr temp_space_allocator_;
54+
};
55+
56+
} // namespace adapter
57+
} // namespace ep
58+
} // namespace onnxruntime
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/adapters.h instead."
8+
#endif
9+
10+
#include <memory>
11+
12+
namespace onnxruntime {
13+
namespace ep {
14+
namespace adapter {
15+
16+
/// <summary>
17+
/// An adapter class partially implementing the interface of `onnxruntime::KernelDef`.
18+
/// </summary>
19+
class KernelDef {
20+
public:
21+
explicit KernelDef(const OrtKernelInfo* kernel_info) : kernel_info_{kernel_info} {}
22+
23+
const std::string OpName() const {
24+
return kernel_info_.GetNodeName();
25+
}
26+
27+
const std::string Domain() const {
28+
return kernel_info_.GetOperatorDomain();
29+
}
30+
31+
private:
32+
const Ort::ConstKernelInfo kernel_info_;
33+
};
34+
35+
} // namespace adapter
36+
} // namespace ep
37+
} // namespace onnxruntime
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/adapters.h instead."
8+
#endif
9+
10+
#include <memory>
11+
12+
#include "core/framework/data_types.h"
13+
14+
namespace onnxruntime {
15+
namespace ep {
16+
namespace adapter {
17+
18+
/// <summary>
19+
/// Gets an OrtMLDataType for a tensor type. Throws on error.
20+
/// </summary>
21+
/// <param name="elem_type"></param>
22+
/// <returns></returns>
23+
inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) {
24+
const OrtEpApi& ep_api = Ort::GetEpApi();
25+
const OrtDataType* result = nullptr;
26+
27+
Ort::ThrowOnError(ep_api.GetTensorDataType(elem_type, &result));
28+
return result;
29+
}
30+
31+
inline const OrtDataType* MLDataTypeToOrtDataType(MLDataType ml_type) {
32+
auto tensor_type = ml_type->AsTensorType();
33+
EP_ENFORCE(tensor_type != nullptr, "EP Kernel registration only supports tensor types.");
34+
auto elem_type = tensor_type->GetElementType();
35+
auto primitive_type = static_cast<const PrimitiveDataTypeBase*>(elem_type);
36+
auto onnx_type = static_cast<ONNXTensorElementDataType>(primitive_type->GetDataType());
37+
return GetTensorType(onnx_type);
38+
}
39+
40+
/// <summary>
41+
/// An adapter class partially implementing the interface of `onnxruntime::KernelDefBuilder`.
42+
/// </summary>
43+
struct KernelDefBuilder {
44+
static std::unique_ptr<KernelDefBuilder> Create() { return std::make_unique<KernelDefBuilder>(); }
45+
46+
explicit KernelDefBuilder() {}
47+
48+
KernelDefBuilder& SetName(const char* op_name) {
49+
builder_.SetOperatorType(op_name);
50+
return *this;
51+
}
52+
53+
KernelDefBuilder& SetDomain(const char* domain) {
54+
builder_.SetDomain(domain);
55+
return *this;
56+
}
57+
58+
KernelDefBuilder& SinceVersion(int since_version) {
59+
return SinceVersion(since_version, INT_MAX);
60+
}
61+
62+
KernelDefBuilder& SinceVersion(int since_version_start, int since_version_end) {
63+
builder_.SetSinceVersion(since_version_start, since_version_end);
64+
return *this;
65+
}
66+
67+
KernelDefBuilder& Provider(const char* provider_type) {
68+
builder_.SetExecutionProvider(provider_type);
69+
return *this;
70+
}
71+
72+
KernelDefBuilder& TypeConstraint(const char* arg_name, std::vector<MLDataType> types) {
73+
std::vector<const OrtDataType*> ort_types;
74+
ort_types.reserve(types.size());
75+
for (const auto& type : types) {
76+
ort_types.push_back(MLDataTypeToOrtDataType(type));
77+
}
78+
builder_.AddTypeConstraint(arg_name, ort_types);
79+
return *this;
80+
}
81+
82+
KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType type) {
83+
builder_.AddTypeConstraint(arg_name, MLDataTypeToOrtDataType(type));
84+
return *this;
85+
}
86+
87+
KernelDefBuilder& MayInplace(const std::vector<std::pair<int, int>>& inplaces) {
88+
for (const auto& pair : inplaces) {
89+
builder_.AddInputOutputMutableAlias(pair.first, pair.second);
90+
}
91+
return *this;
92+
}
93+
KernelDefBuilder& MayInplace(int input_index, int output_index) {
94+
builder_.AddInputOutputMutableAlias(input_index, output_index);
95+
return *this;
96+
}
97+
98+
KernelDefBuilder& Alias(const std::vector<std::pair<int, int>>& aliases) {
99+
for (const auto& pair : aliases) {
100+
builder_.AddInputOutputAlias(pair.first, pair.second);
101+
}
102+
return *this;
103+
}
104+
KernelDefBuilder& Alias(int input_index, int output_index) {
105+
builder_.AddInputOutputAlias(input_index, output_index);
106+
return *this;
107+
}
108+
109+
KernelDefBuilder& InputMemoryType(OrtMemType type, int input_index) {
110+
builder_.SetInputMemType(input_index, type);
111+
return *this;
112+
}
113+
114+
KernelDefBuilder& InputMemoryType(OrtMemType type, const std::vector<int>& input_indexes) {
115+
for (int input_index : input_indexes) {
116+
builder_.SetInputMemType(input_index, type);
117+
}
118+
return *this;
119+
}
120+
121+
KernelDefBuilder& OutputMemoryType(OrtMemType type, int output_index) {
122+
builder_.SetOutputMemType(output_index, type);
123+
return *this;
124+
}
125+
126+
KernelDefBuilder& OutputMemoryType(OrtMemType type, const std::vector<int>& output_indexes) {
127+
for (int output_index : output_indexes) {
128+
builder_.SetOutputMemType(output_index, type);
129+
}
130+
return *this;
131+
}
132+
133+
KernelDefBuilder& ExecQueueId(int /*queue_id*/) { return *this; }
134+
135+
Ort::KernelDef Build() { return builder_.Build(); }
136+
137+
private:
138+
Ort::KernelDefBuilder builder_;
139+
};
140+
141+
} // namespace adapter
142+
} // namespace ep
143+
} // namespace onnxruntime

0 commit comments

Comments
 (0)