Skip to content

Commit 0a2cac8

Browse files
Merge commit 0a2cac8 (2 parents: 16c787a + f1abea1)

File tree

2 files changed

+76
-76
lines changed

2 files changed

+76
-76
lines changed

CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
3636
cmake_policy(SET CMP0077 NEW)
3737
endif()
3838

39-
# Avoid warning of Calling FetchContent_Populate(GSL) is deprecated
39+
# Avoid warning of Calling FetchContent_Populate(GSL) is deprecated temporarily
40+
# TODO: find a better way to handle the header-only 3rd party deps
4041
if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.30.0")
41-
cmake_policy(CMP0169 OLD)
42+
cmake_policy(SET CMP0169 OLD)
4243
endif()
4344

4445
# Needed for Java

include/custom_op/custom_op_lite.h

Lines changed: 73 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,19 @@ namespace Custom {
1616

1717
class OrtKernelContextStorage : public ITensorStorage {
1818
public:
19-
OrtKernelContextStorage(const OrtW::CustomOpApi& api,
19+
OrtKernelContextStorage(const OrtW::CustomOpApi& custom_op_api,
2020
OrtKernelContext& ctx,
2121
size_t indice,
22-
bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
22+
bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
2323
if (is_input) {
24-
auto input_count = api.KernelContext_GetInputCount(&ctx);
24+
auto input_count = api_.KernelContext_GetInputCount(&ctx);
2525
if (indice >= input_count) {
2626
ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
2727
}
28-
const_value_ = api.KernelContext_GetInput(&ctx, indice);
29-
auto* info = api.GetTensorTypeAndShape(const_value_);
30-
shape_ = api.GetTensorShape(info);
31-
api.ReleaseTensorTypeAndShapeInfo(info);
28+
const_value_ = api_.KernelContext_GetInput(&ctx, indice);
29+
auto* info = api_.GetTensorTypeAndShape(const_value_);
30+
shape_ = api_.GetTensorShape(info);
31+
api_.ReleaseTensorTypeAndShapeInfo(info);
3232
}
3333
}
3434

@@ -66,18 +66,18 @@ class OrtKernelContextStorage : public ITensorStorage {
6666
std::optional<std::vector<int64_t>> shape_;
6767
};
6868

69-
static std::string get_mem_type(const OrtW::CustomOpApi& api,
70-
OrtKernelContext& ctx,
71-
size_t indice,
72-
bool is_input){
69+
static std::string get_mem_type(const OrtW::CustomOpApi& custom_op_api,
70+
OrtKernelContext& ctx,
71+
size_t indice,
72+
bool is_input) {
7373
std::string output = "Cpu";
7474
if (is_input) {
75-
const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice);
75+
const OrtValue* const_value = custom_op_api.KernelContext_GetInput(&ctx, indice);
7676
const OrtMemoryInfo* mem_info = {};
77-
api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info));
77+
custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info));
7878
if (mem_info) {
7979
const char* mem_type = nullptr;
80-
api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type));
80+
custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type));
8181
if (mem_type) {
8282
output = mem_type;
8383
}
@@ -88,29 +88,29 @@ static std::string get_mem_type(const OrtW::CustomOpApi& api,
8888

8989
template <typename T>
9090
class OrtTensor : public Tensor<T> {
91-
public:
92-
OrtTensor(const OrtW::CustomOpApi& api,
93-
OrtKernelContext& ctx,
94-
size_t indice,
95-
bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(api, ctx, indice, is_input)),
96-
mem_type_(get_mem_type(api, ctx, indice, is_input)) {
91+
public:
92+
OrtTensor(const OrtW::CustomOpApi& custom_op_api,
93+
OrtKernelContext& ctx,
94+
size_t indice,
95+
bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(custom_op_api, ctx, indice, is_input)),
96+
mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
9797
}
9898

9999
bool IsCpuTensor() const {
100100
return mem_type_ == "Cpu";
101101
}
102102

103-
private:
103+
private:
104104
std::string mem_type_ = "Cpu";
105105
};
106106

107107
class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
108108
public:
109109
using strings = std::vector<std::string>;
110-
OrtStringTensorStorage(const OrtW::CustomOpApi& api,
110+
OrtStringTensorStorage(const OrtW::CustomOpApi& custom_op_api,
111111
OrtKernelContext& ctx,
112112
size_t indice,
113-
bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
113+
bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
114114
if (is_input) {
115115
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
116116
if (indice >= input_count) {
@@ -197,10 +197,10 @@ class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
197197
class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view> {
198198
public:
199199
using strings = std::vector<std::string_view>;
200-
OrtStringViewTensorStorage(const OrtW::CustomOpApi& api,
200+
OrtStringViewTensorStorage(const OrtW::CustomOpApi& custom_op_api,
201201
OrtKernelContext& ctx,
202202
size_t indice,
203-
bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
203+
bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
204204
if (is_input) {
205205
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
206206
if (indice >= input_count) {
@@ -275,57 +275,56 @@ class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view>
275275

276276
// to make the metaprogramming magic happy.
277277
template <>
278-
class OrtTensor<std::string> : public Tensor<std::string>{
279-
public:
280-
OrtTensor(const OrtW::CustomOpApi& api,
281-
OrtKernelContext& ctx,
282-
size_t indice,
283-
bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(api, ctx, indice, is_input)),
284-
mem_type_(get_mem_type(api, ctx, indice, is_input)) {}
285-
278+
class OrtTensor<std::string> : public Tensor<std::string> {
279+
public:
280+
OrtTensor(const OrtW::CustomOpApi& custom_op_api,
281+
OrtKernelContext& ctx,
282+
size_t indice,
283+
bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(custom_op_api, ctx, indice, is_input)),
284+
mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {}
285+
286286
bool IsCpuTensor() const {
287287
return mem_type_ == "Cpu";
288288
}
289289

290-
private:
290+
private:
291291
std::string mem_type_ = "Cpu";
292292
};
293293

294294
template <>
295-
class OrtTensor<std::string_view> : public Tensor<std::string_view>{
296-
public:
297-
OrtTensor(const OrtW::CustomOpApi& api,
298-
OrtKernelContext& ctx,
299-
size_t indice,
300-
bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(api, ctx, indice, is_input)),
301-
mem_type_(get_mem_type(api, ctx, indice, is_input)) {}
295+
class OrtTensor<std::string_view> : public Tensor<std::string_view> {
296+
public:
297+
OrtTensor(const OrtW::CustomOpApi& custom_op_api,
298+
OrtKernelContext& ctx,
299+
size_t indice,
300+
bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(custom_op_api, ctx, indice, is_input)),
301+
mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {}
302302

303303
bool IsCpuTensor() const {
304304
return mem_type_ == "Cpu";
305305
}
306306

307-
private:
307+
private:
308308
std::string mem_type_ = "Cpu";
309309
};
310310

311311
using TensorPtr = std::unique_ptr<Custom::Arg>;
312312
using TensorPtrs = std::vector<TensorPtr>;
313313

314-
315314
using TensorBasePtr = std::unique_ptr<Custom::TensorBase>;
316315
using TensorBasePtrs = std::vector<TensorBasePtr>;
317316

318317
// Represent variadic input or output
319318
struct Variadic : public Arg {
320-
Variadic(const OrtW::CustomOpApi& api,
319+
Variadic(const OrtW::CustomOpApi& custom_op_api,
321320
OrtKernelContext& ctx,
322321
size_t indice,
323-
bool is_input) : api_(api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(api, ctx, indice, is_input)) {
322+
bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
324323
#if ORT_API_VERSION < 14
325324
ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
326325
#endif
327326
if (is_input) {
328-
auto input_count = api.KernelContext_GetInputCount(&ctx_);
327+
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
329328
for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
330329
auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
331330
auto* info = api_.GetTensorTypeAndShape(const_value);
@@ -334,40 +333,40 @@ struct Variadic : public Arg {
334333
TensorBasePtr tensor;
335334
switch (type) {
336335
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
337-
tensor = std::make_unique<Custom::OrtTensor<bool>>(api, ctx, ith_input, true);
336+
tensor = std::make_unique<Custom::OrtTensor<bool>>(api_, ctx, ith_input, true);
338337
break;
339338
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
340-
tensor = std::make_unique<Custom::OrtTensor<float>>(api, ctx, ith_input, true);
339+
tensor = std::make_unique<Custom::OrtTensor<float>>(api_, ctx, ith_input, true);
341340
break;
342341
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
343-
tensor = std::make_unique<Custom::OrtTensor<double>>(api, ctx, ith_input, true);
342+
tensor = std::make_unique<Custom::OrtTensor<double>>(api_, ctx, ith_input, true);
344343
break;
345344
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
346-
tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api, ctx, ith_input, true);
345+
tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api_, ctx, ith_input, true);
347346
break;
348347
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
349-
tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api, ctx, ith_input, true);
348+
tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api_, ctx, ith_input, true);
350349
break;
351350
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
352-
tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api, ctx, ith_input, true);
351+
tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api_, ctx, ith_input, true);
353352
break;
354353
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
355-
tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api, ctx, ith_input, true);
354+
tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api_, ctx, ith_input, true);
356355
break;
357356
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
358-
tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api, ctx, ith_input, true);
357+
tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api_, ctx, ith_input, true);
359358
break;
360359
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
361-
tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api, ctx, ith_input, true);
360+
tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api_, ctx, ith_input, true);
362361
break;
363362
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
364-
tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api, ctx, ith_input, true);
363+
tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api_, ctx, ith_input, true);
365364
break;
366365
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
367-
tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api, ctx, ith_input, true);
366+
tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api_, ctx, ith_input, true);
368367
break;
369368
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
370-
tensor = std::make_unique<Custom::OrtTensor<std::string>>(api, ctx, ith_input, true);
369+
tensor = std::make_unique<Custom::OrtTensor<std::string>>(api_, ctx, ith_input, true);
371370
break;
372371
default:
373372
ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
@@ -395,7 +394,7 @@ struct Variadic : public Arg {
395394
size_t Size() const {
396395
return tensors_.size();
397396
}
398-
397+
399398
const TensorBasePtr& operator[](size_t indice) const {
400399
return tensors_.at(indice);
401400
}
@@ -412,11 +411,11 @@ struct Variadic : public Arg {
412411

413412
class OrtGraphKernelContext : public KernelContext {
414413
public:
415-
OrtGraphKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
414+
OrtGraphKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) {
416415
OrtMemoryInfo* info;
417-
OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
418-
OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &allocator_));
419-
api.ReleaseMemoryInfo(info);
416+
OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
417+
OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &allocator_));
418+
api_.ReleaseMemoryInfo(info);
420419
}
421420

422421
virtual ~OrtGraphKernelContext() {
@@ -458,31 +457,31 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext {
458457
public:
459458
static const int cuda_resource_ver = 1;
460459

461-
OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
462-
api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
460+
OrtGraphCudaKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) {
461+
api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
463462
if (!cuda_stream_) {
464463
ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
465464
}
466-
api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_);
465+
api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_);
467466
if (!cublas_) {
468467
ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
469468
}
470469
void* resource = nullptr;
471-
OrtStatusPtr result = api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
470+
OrtStatusPtr result = api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
472471
if (result) {
473472
ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
474473
}
475474
memcpy(&device_id_, &resource, sizeof(int));
476475

477476
OrtMemoryInfo* info;
478-
OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
479-
OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_));
480-
api.ReleaseMemoryInfo(info);
477+
OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
478+
OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_));
479+
api_.ReleaseMemoryInfo(info);
481480

482481
OrtMemoryInfo* cuda_mem_info;
483-
OrtW::ThrowOnError(api, api.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
484-
OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_));
485-
api.ReleaseMemoryInfo(cuda_mem_info);
482+
OrtW::ThrowOnError(api_, api_.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
483+
OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_));
484+
api_.ReleaseMemoryInfo(cuda_mem_info);
486485
}
487486

488487
virtual ~OrtGraphCudaKernelContext() {
@@ -944,7 +943,7 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
944943

945944
class OrtAttributeReader {
946945
public:
947-
OrtAttributeReader(const OrtApi& api, const OrtKernelInfo& info) : base_kernel_(api, info) {
946+
OrtAttributeReader(const OrtApi& ort_api, const OrtKernelInfo& info) : base_kernel_(ort_api, info) {
948947
}
949948

950949
template <class T>

0 commit comments

Comments (0)