Skip to content

Commit 2234001

Browse files
authored
refactor ORT-Extension for the coming GroupQueryAttention work (#674)
* refactor ORT-Extension for the coming GroupQueryAttention work * fix typo and add #if ORT_API_VERSION >= 15 for GetOrtAllocator * fix cuda build
1 parent 2321329 commit 2234001

17 files changed

+326
-134
lines changed

base/ortx_common.h

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
#pragma once
4+
#include <locale>
5+
#include <optional>
6+
#include <string>
7+
#include <sstream>
8+
#include "string_utils.h"
9+
#ifdef _WIN32
10+
#include <Windows.h>
11+
#endif
12+
13+
// Evaluates `expr` exactly once. If the result compares unequal to nullptr,
// returns it from the enclosing function; otherwise execution continues.
// Wrapped in do/while(0) so the macro behaves as a single statement.
#define ORTX_RETURN_IF_ERROR(expr)                                  \
  do {                                                              \
    if (auto ortx_status_ = (expr); ortx_status_ != nullptr) {      \
      return ortx_status_;                                          \
    }                                                               \
  } while (0)
20+
21+
/// Parses `str` into `value` using the classic ("C") locale.
///
/// The whole string must be consumed by the stream extraction: leading
/// whitespace and trailing characters are rejected, and for unsigned integral
/// `T` a leading '-' is rejected (it would otherwise wrap around).
///
/// @param str    Text to parse.
/// @param value  Receives the parsed value; left untouched on failure.
/// @return true when parsing succeeded, false otherwise.
template <typename T>
bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
  if (!str.empty()) {
    const char first = str.front();

    if constexpr (std::is_integral<T>::value && std::is_unsigned<T>::value) {
      // a negative literal would silently wrap for unsigned targets
      if (first == '-') {
        return false;
      }
    }

    // don't allow leading whitespace
    if (std::isspace(first, std::locale::classic())) {
      return false;
    }
  }

  std::istringstream stream{std::string{str}};
  stream.imbue(std::locale::classic());

  T result{};
  if (!(stream >> result)) {
    return false;
  }

  // don't allow trailing characters
  if (stream.get() != std::istringstream::traits_type::eof()) {
    return false;
  }

  value = std::move(result);
  return true;
}
49+
50+
/// String overload: no parsing is needed, so `str` is copied verbatim into
/// `value` and the call always reports success.
inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
  value.assign(str.data(), str.size());
  return true;
}
54+
55+
/// Boolean overload: accepts exactly "0"/"False"/"false" and "1"/"True"/"true".
///
/// @param str    Text to interpret.
/// @param value  Receives the boolean; left untouched when `str` is not recognized.
/// @return true when `str` is one of the accepted spellings, false otherwise.
inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
  const bool is_false = (str == "0" || str == "False" || str == "false");
  const bool is_true = (str == "1" || str == "True" || str == "true");

  if (!is_false && !is_true) {
    return false;
  }

  value = is_true;
  return true;
}
68+
69+
template <typename T>
70+
std::optional<T> ParseEnvironmentVariable(const std::string& name) {
71+
std::string buffer;
72+
#ifdef _WIN32
73+
constexpr size_t kBufferSize = 32767;
74+
75+
// Create buffer to hold the result
76+
buffer.resize(kBufferSize, '\0');
77+
78+
// The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters.
79+
// If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character.
80+
// Therefore, If the function succeeds, kBufferSize should be larger than char_count.
81+
auto char_count = GetEnvironmentVariableA(name.c_str(), buffer.data(), kBufferSize);
82+
83+
if (kBufferSize > char_count) {
84+
buffer.resize(char_count);
85+
} else {
86+
// Else either the call was failed, or the buffer wasn't large enough.
87+
// TODO: Understand the reason for failure by calling GetLastError().
88+
// If it is due to the specified environment variable being found in the environment block,
89+
// GetLastError() returns ERROR_ENVVAR_NOT_FOUND.
90+
// For now, we assume that the environment variable is not found.
91+
buffer.clear();
92+
}
93+
#else
94+
char* val = getenv(name.c_str());
95+
buffer = (val == nullptr) ? std::string() : std::string(val);
96+
#endif
97+
T parsed_value;
98+
if (!TryParseStringWithClassicLocale(buffer, parsed_value)) {
99+
OrtW::Exception(MakeString("Failed to parse environment variable - name: ", name, ", value: ", buffer), OrtErrorCode::ORT_FAIL);
100+
}
101+
return parsed_value;
102+
}
103+
104+
template <typename T>
105+
T ParseEnvironmentVariableWithDefault(const std::string& name, const T& default_value) {
106+
const auto parsed = ParseEnvironmentVariable<T>(name);
107+
if (parsed.has_value()) {
108+
return *parsed;
109+
}
110+
111+
return default_value;
112+
}
113+
114+
/// Returns true for a rank-0 shape, or a rank-1 shape whose `shape_size` is 1.
///
/// @param num_dimensions  Number of dimensions of the shape.
/// @param shape_size      Size value checked against 1 in the rank-1 case.
inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size) {
  const bool is_scalar = (num_dimensions == 0);
  const bool is_single_element_vector = (num_dimensions == 1 && shape_size == 1);
  return is_scalar || is_single_element_vector;
}

base/string_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#pragma once
44
#include <sstream>
55
#include <vector>
6-
#include "ocos.h"
6+
#include "onnxruntime_cpp_api_legacy.hpp"
77

88
template <typename T>
99
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {

docs/development.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ The package contains all custom operators and some Python scripts to manipulate
1717
- no-opencv: disable operators based on OpenCV in build.
1818
- cc-debug: Generate debug info for extensions binaries and disable C/C++ compiler optimization.
1919

20-
For example:`pip install --config-settings "ortx-user-option=use-cuda,cc-debug" `, This command builds CUDA kernels into the package and installs it, accompanied by the generation of debug information.
20+
For example: `pip install . --config-settings "ortx-user-option=use-cuda,cc-debug"`. This command builds CUDA kernels into the package and installs it, and also generates debug information.
2121

2222
Test:
2323

includes/custom_op_lite.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,11 @@ struct Variadic : public TensorBase {
585585

586586
// Resource identifiers handed to KernelContext_GetResource.
// Values are spelled out explicitly so each id is visible at a glance;
// they are the same consecutive values the implicit numbering produced.
enum CudaResource {
  cuda_handle_t = 10000,
  cudnn_handle_t = 10001,
  cublas_handle_t = 10002,
  deferred_cpu_allocator_t = 10003,
  // below are cuda ep options
  device_id_t = 10004,
};
589594

590595
struct CudaContext {
@@ -595,8 +600,20 @@ struct CudaContext {
595600
if (!cuda_stream) {
596601
ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
597602
}
603+
ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas);
604+
if (!cublas) {
605+
ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
606+
}
607+
void* resource = nullptr;
608+
OrtStatusPtr result = ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
609+
if (result) {
610+
ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
611+
}
612+
memcpy(&device_id, &resource, sizeof(int));
598613
}
599614
void* cuda_stream = {};
615+
void* cublas = {};
616+
int device_id = 0;
600617
};
601618

602619
#endif

includes/onnxruntime_cpp_api_legacy.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#pragma once
55
#include <vector>
6-
#include "onnxruntime_c_api.h"
6+
#include "exceptions.h"
77

88
//
99
// DEPRECATED: All new custom OPs should not use any class/struct/functions from this file.

includes/onnxruntime_customop.hpp

Lines changed: 14 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -15,118 +15,16 @@
1515
#include <utility>
1616
#include <type_traits>
1717
#include <optional>
18+
#include <functional>
1819

19-
#include "onnxruntime_c_api.h"
2020
#include "exceptions.h"
21+
#include "onnxruntime_no_customop.h"
2122
#include "onnxruntime_cpp_api_legacy.hpp"
2223
#include "onnxruntime_extensions.h"
2324
#include "custom_op_lite.h"
2425

2526
#define MIN_ORT_VERSION_SUPPORTED 11
2627

27-
// namespace of ORT ABI Wrapper
28-
namespace OrtW {
29-
30-
class API {
31-
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
32-
public:
33-
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
34-
static API self(ort_api);
35-
return self;
36-
}
37-
38-
static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
39-
return instance()->CreateStatus(code, msg);
40-
}
41-
42-
static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
43-
instance()->ReleaseStatus(ptr);
44-
}
45-
46-
template <typename T>
47-
static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
48-
49-
static void ThrowOnError(OrtStatusPtr ptr) {
50-
OrtW::ThrowOnError(instance().api_, ptr);
51-
}
52-
53-
private:
54-
const OrtApi* operator->() const {
55-
return &api_;
56-
}
57-
58-
API(const OrtApi* api) : api_(*api) {
59-
if (api == nullptr) {
60-
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
61-
}
62-
}
63-
64-
const OrtApi& api_;
65-
};
66-
67-
template <>
68-
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
69-
return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
70-
}
71-
72-
template <>
73-
inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
74-
return instance()->KernelInfoGetAttribute_float(&info, name, &value);
75-
}
76-
77-
template <>
78-
inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
79-
size_t size = 0;
80-
std::string out;
81-
// Feed nullptr for the data buffer to query the true size of the string attribute
82-
OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
83-
if (status == nullptr) {
84-
out.resize(size);
85-
status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
86-
out.resize(size - 1); // remove the terminating character '\0'
87-
}
88-
89-
if (status == nullptr) {
90-
value = std::move(out);
91-
}
92-
93-
return status;
94-
}
95-
96-
template <class T>
97-
inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
98-
if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
99-
// Ideally, we should know which kind of error code can be ignored, but it is not available now.
100-
// Just ignore all of them.
101-
API::ReleaseStatus(status);
102-
}
103-
104-
return nullptr;
105-
}
106-
107-
inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
108-
return API::CreateStatus(code, msg);
109-
}
110-
111-
inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
112-
return API::CreateStatus(code, msg.c_str());
113-
}
114-
115-
inline void ReleaseStatus(OrtStatusPtr& status) {
116-
API::ReleaseStatus(status);
117-
status = nullptr;
118-
}
119-
120-
} // namespace OrtW
121-
122-
// Evaluates `expr` exactly once. If the result compares unequal to nullptr,
// returns it from the enclosing function; otherwise execution continues.
// Wrapped in do/while(0) so the macro behaves as a single statement.
#define ORTX_RETURN_IF_ERROR(expr)                                  \
  do {                                                              \
    if (auto ortx_status_ = (expr); ortx_status_ != nullptr) {      \
      return ortx_status_;                                          \
    }                                                               \
  } while (0)
129-
13028
namespace Ort {
13129
namespace Custom {
13230

@@ -164,6 +62,12 @@ struct ComputeArgsList<OrtStatusPtr (C::*)(Args...) const> {
16462
using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const;
16563
};
16664

65+
// Detection trait: `value` is true when `&T::GetInputMemoryType` is a
// well-formed expression, and false otherwise.
// Primary template: chosen when the partial specialization below is not viable.
template <typename T, typename Enable = void>
struct CustomOp_defined_getInputMemoryType : std::bool_constant<false> {};

// Partial specialization: viable only when `&T::GetInputMemoryType` is well-formed.
template <typename T>
struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>>
    : std::bool_constant<true> {};
70+
16771
template <typename CustomOpKernel>
16872
struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
16973
using ComputeFunction = decltype(&CustomOpKernel::Compute);
@@ -236,6 +140,12 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
236140
OrtCustomOp::CreateKernel = nullptr;
237141
OrtCustomOp::KernelCompute = nullptr;
238142

143+
if constexpr (CustomOp_defined_getInputMemoryType<CustomOpKernel>::value) {
144+
OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* /*this_*/, size_t index) -> OrtMemType {
145+
return CustomOpKernel::GetInputMemoryType(index);
146+
};
147+
}
148+
239149
OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
240150
const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
241151
if (api == nullptr) {

0 commit comments

Comments
 (0)