|
15 | 15 | #include <utility> |
16 | 16 | #include <type_traits> |
17 | 17 | #include <optional> |
| 18 | +#include <functional> |
18 | 19 |
|
19 | | -#include "onnxruntime_c_api.h" |
20 | 20 | #include "exceptions.h" |
| 21 | +#include "onnxruntime_no_customop.h" |
21 | 22 | #include "onnxruntime_cpp_api_legacy.hpp" |
22 | 23 | #include "onnxruntime_extensions.h" |
23 | 24 | #include "custom_op_lite.h" |
24 | 25 |
|
25 | 26 | #define MIN_ORT_VERSION_SUPPORTED 11 |
26 | 27 |
|
27 | | -// namespace of ORT ABI Wrapper |
28 | | -namespace OrtW { |
29 | | - |
30 | | -class API { |
31 | | - // To use ONNX C ABI in a way like OrtW::API::CreateStatus. |
32 | | - public: |
33 | | - static API& instance(const OrtApi* ort_api = nullptr) noexcept { |
34 | | - static API self(ort_api); |
35 | | - return self; |
36 | | - } |
37 | | - |
38 | | - static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept { |
39 | | - return instance()->CreateStatus(code, msg); |
40 | | - } |
41 | | - |
42 | | - static void ReleaseStatus(OrtStatusPtr ptr) noexcept { |
43 | | - instance()->ReleaseStatus(ptr); |
44 | | - } |
45 | | - |
46 | | - template <typename T> |
47 | | - static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept; |
48 | | - |
49 | | - static void ThrowOnError(OrtStatusPtr ptr) { |
50 | | - OrtW::ThrowOnError(instance().api_, ptr); |
51 | | - } |
52 | | - |
53 | | - private: |
54 | | - const OrtApi* operator->() const { |
55 | | - return &api_; |
56 | | - } |
57 | | - |
58 | | - API(const OrtApi* api) : api_(*api) { |
59 | | - if (api == nullptr) { |
60 | | - ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION); |
61 | | - } |
62 | | - } |
63 | | - |
64 | | - const OrtApi& api_; |
65 | | -}; |
66 | | - |
67 | | -template <> |
68 | | -inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept { |
69 | | - return instance()->KernelInfoGetAttribute_int64(&info, name, &value); |
70 | | -} |
71 | | - |
72 | | -template <> |
73 | | -inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept { |
74 | | - return instance()->KernelInfoGetAttribute_float(&info, name, &value); |
75 | | -} |
76 | | - |
77 | | -template <> |
78 | | -inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept { |
79 | | - size_t size = 0; |
80 | | - std::string out; |
81 | | - // Feed nullptr for the data buffer to query the true size of the string attribute |
82 | | - OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size); |
83 | | - if (status == nullptr) { |
84 | | - out.resize(size); |
85 | | - status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size); |
86 | | - out.resize(size - 1); // remove the terminating character '\0' |
87 | | - } |
88 | | - |
89 | | - if (status == nullptr) { |
90 | | - value = std::move(out); |
91 | | - } |
92 | | - |
93 | | - return status; |
94 | | -} |
95 | | - |
96 | | -template <class T> |
97 | | -inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept { |
98 | | - if (auto status = API::KernelInfoGetAttribute(info, name, value); status) { |
99 | | - // Ideally, we should know which kind of error code can be ignored, but it is not available now. |
100 | | - // Just ignore all of them. |
101 | | - API::ReleaseStatus(status); |
102 | | - } |
103 | | - |
104 | | - return nullptr; |
105 | | -} |
106 | | - |
107 | | -inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) { |
108 | | - return API::CreateStatus(code, msg); |
109 | | -} |
110 | | - |
111 | | -inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) { |
112 | | - return API::CreateStatus(code, msg.c_str()); |
113 | | -} |
114 | | - |
115 | | -inline void ReleaseStatus(OrtStatusPtr& status) { |
116 | | - API::ReleaseStatus(status); |
117 | | - status = nullptr; |
118 | | -} |
119 | | - |
120 | | -} // namespace OrtW |
121 | | - |
122 | | -#define ORTX_RETURN_IF_ERROR(expr) \ |
123 | | - do { \ |
124 | | - auto _status = (expr); \ |
125 | | - if (_status != nullptr) { \ |
126 | | - return _status; \ |
127 | | - } \ |
128 | | - } while (0) |
129 | | - |
130 | 28 | namespace Ort { |
131 | 29 | namespace Custom { |
132 | 30 |
|
@@ -164,6 +62,12 @@ struct ComputeArgsList<OrtStatusPtr (C::*)(Args...) const> { |
164 | 62 | using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const; |
165 | 63 | }; |
166 | 64 |
|
| 65 | +template <typename T, typename = void> |
| 66 | +struct CustomOp_defined_getInputMemoryType : std::false_type {}; |
| 67 | + |
| 68 | +template <typename T> |
| 69 | +struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>> : std::true_type {}; |
| 70 | + |
167 | 71 | template <typename CustomOpKernel> |
168 | 72 | struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { |
169 | 73 | using ComputeFunction = decltype(&CustomOpKernel::Compute); |
@@ -236,6 +140,12 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { |
236 | 140 | OrtCustomOp::CreateKernel = nullptr; |
237 | 141 | OrtCustomOp::KernelCompute = nullptr; |
238 | 142 |
|
| 143 | + if constexpr (CustomOp_defined_getInputMemoryType<CustomOpKernel>::value) { |
| 144 | + OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* /*this_*/, size_t index) -> OrtMemType { |
| 145 | + return CustomOpKernel::GetInputMemoryType(index); |
| 146 | + }; |
| 147 | + } |
| 148 | + |
239 | 149 | OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, |
240 | 150 | const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr { |
241 | 151 | if (api == nullptr) { |
|
0 commit comments