@@ -16,19 +16,19 @@ namespace Custom {
1616
1717class OrtKernelContextStorage : public ITensorStorage {
1818 public:
19- OrtKernelContextStorage (const OrtW::CustomOpApi& api ,
19+ OrtKernelContextStorage (const OrtW::CustomOpApi& custom_op_api ,
2020 OrtKernelContext& ctx,
2121 size_t indice,
22- bool is_input) : api_(api ), ctx_(ctx), indice_(indice) {
22+ bool is_input) : api_(custom_op_api ), ctx_(ctx), indice_(indice) {
2323 if (is_input) {
24- auto input_count = api .KernelContext_GetInputCount (&ctx);
24+ auto input_count = api_ .KernelContext_GetInputCount (&ctx);
2525 if (indice >= input_count) {
2626 ORTX_CXX_API_THROW (" invalid indice" , ORT_RUNTIME_EXCEPTION);
2727 }
28- const_value_ = api .KernelContext_GetInput (&ctx, indice);
29- auto * info = api .GetTensorTypeAndShape (const_value_);
30- shape_ = api .GetTensorShape (info);
31- api .ReleaseTensorTypeAndShapeInfo (info);
28+ const_value_ = api_ .KernelContext_GetInput (&ctx, indice);
29+ auto * info = api_ .GetTensorTypeAndShape (const_value_);
30+ shape_ = api_ .GetTensorShape (info);
31+ api_ .ReleaseTensorTypeAndShapeInfo (info);
3232 }
3333 }
3434
@@ -66,18 +66,18 @@ class OrtKernelContextStorage : public ITensorStorage {
6666 std::optional<std::vector<int64_t >> shape_;
6767};
6868
69- static std::string get_mem_type (const OrtW::CustomOpApi& api ,
70- OrtKernelContext& ctx,
71- size_t indice,
72- bool is_input){
69+ static std::string get_mem_type (const OrtW::CustomOpApi& custom_op_api ,
70+ OrtKernelContext& ctx,
71+ size_t indice,
72+ bool is_input) {
7373 std::string output = " Cpu" ;
7474 if (is_input) {
75- const OrtValue* const_value = api .KernelContext_GetInput (&ctx, indice);
75+ const OrtValue* const_value = custom_op_api .KernelContext_GetInput (&ctx, indice);
7676 const OrtMemoryInfo* mem_info = {};
77- api .ThrowOnError (api .GetOrtApi ().GetTensorMemoryInfo (const_value, &mem_info));
77+ custom_op_api .ThrowOnError (custom_op_api .GetOrtApi ().GetTensorMemoryInfo (const_value, &mem_info));
7878 if (mem_info) {
7979 const char * mem_type = nullptr ;
80- api .ThrowOnError (api .GetOrtApi ().MemoryInfoGetName (mem_info, &mem_type));
80+ custom_op_api .ThrowOnError (custom_op_api .GetOrtApi ().MemoryInfoGetName (mem_info, &mem_type));
8181 if (mem_type) {
8282 output = mem_type;
8383 }
@@ -88,29 +88,29 @@ static std::string get_mem_type(const OrtW::CustomOpApi& api,
8888
8989template <typename T>
9090class OrtTensor : public Tensor <T> {
91- public:
92- OrtTensor (const OrtW::CustomOpApi& api ,
93- OrtKernelContext& ctx,
94- size_t indice,
95- bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(api , ctx, indice, is_input)),
96- mem_type_ (get_mem_type(api , ctx, indice, is_input)) {
91+ public:
92+ OrtTensor (const OrtW::CustomOpApi& custom_op_api ,
93+ OrtKernelContext& ctx,
94+ size_t indice,
95+ bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(custom_op_api , ctx, indice, is_input)),
96+ mem_type_ (get_mem_type(custom_op_api , ctx, indice, is_input)) {
9797 }
9898
9999 bool IsCpuTensor () const {
100100 return mem_type_ == " Cpu" ;
101101 }
102102
103- private:
103+ private:
104104 std::string mem_type_ = " Cpu" ;
105105};
106106
107107class OrtStringTensorStorage : public IStringTensorStorage <std::string> {
108108 public:
109109 using strings = std::vector<std::string>;
110- OrtStringTensorStorage (const OrtW::CustomOpApi& api ,
110+ OrtStringTensorStorage (const OrtW::CustomOpApi& custom_op_api ,
111111 OrtKernelContext& ctx,
112112 size_t indice,
113- bool is_input) : api_(api ), ctx_(ctx), indice_(indice) {
113+ bool is_input) : api_(custom_op_api ), ctx_(ctx), indice_(indice) {
114114 if (is_input) {
115115 auto input_count = api_.KernelContext_GetInputCount (&ctx_);
116116 if (indice >= input_count) {
@@ -197,10 +197,10 @@ class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
197197class OrtStringViewTensorStorage : public IStringTensorStorage <std::string_view> {
198198 public:
199199 using strings = std::vector<std::string_view>;
200- OrtStringViewTensorStorage (const OrtW::CustomOpApi& api ,
200+ OrtStringViewTensorStorage (const OrtW::CustomOpApi& custom_op_api ,
201201 OrtKernelContext& ctx,
202202 size_t indice,
203- bool is_input) : api_(api ), ctx_(ctx), indice_(indice) {
203+ bool is_input) : api_(custom_op_api ), ctx_(ctx), indice_(indice) {
204204 if (is_input) {
205205 auto input_count = api_.KernelContext_GetInputCount (&ctx_);
206206 if (indice >= input_count) {
@@ -275,57 +275,56 @@ class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view>
275275
276276// to make the metaprogramming magic happy.
277277template <>
278- class OrtTensor <std::string> : public Tensor<std::string>{
279- public:
280- OrtTensor (const OrtW::CustomOpApi& api ,
281- OrtKernelContext& ctx,
282- size_t indice,
283- bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(api , ctx, indice, is_input)),
284- mem_type_ (get_mem_type(api , ctx, indice, is_input)) {}
285-
278+ class OrtTensor <std::string> : public Tensor<std::string> {
279+ public:
280+ OrtTensor (const OrtW::CustomOpApi& custom_op_api ,
281+ OrtKernelContext& ctx,
282+ size_t indice,
283+ bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(custom_op_api , ctx, indice, is_input)),
284+ mem_type_ (get_mem_type(custom_op_api , ctx, indice, is_input)) {}
285+
286286 bool IsCpuTensor () const {
287287 return mem_type_ == " Cpu" ;
288288 }
289289
290- private:
290+ private:
291291 std::string mem_type_ = " Cpu" ;
292292};
293293
294294template <>
295- class OrtTensor <std::string_view> : public Tensor<std::string_view>{
296- public:
297- OrtTensor (const OrtW::CustomOpApi& api ,
298- OrtKernelContext& ctx,
299- size_t indice,
300- bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(api , ctx, indice, is_input)),
301- mem_type_ (get_mem_type(api , ctx, indice, is_input)) {}
295+ class OrtTensor <std::string_view> : public Tensor<std::string_view> {
296+ public:
297+ OrtTensor (const OrtW::CustomOpApi& custom_op_api ,
298+ OrtKernelContext& ctx,
299+ size_t indice,
300+ bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(custom_op_api , ctx, indice, is_input)),
301+ mem_type_ (get_mem_type(custom_op_api , ctx, indice, is_input)) {}
302302
303303 bool IsCpuTensor () const {
304304 return mem_type_ == " Cpu" ;
305305 }
306306
307- private:
307+ private:
308308 std::string mem_type_ = " Cpu" ;
309309};
310310
311311using TensorPtr = std::unique_ptr<Custom::Arg>;
312312using TensorPtrs = std::vector<TensorPtr>;
313313
314-
315314using TensorBasePtr = std::unique_ptr<Custom::TensorBase>;
316315using TensorBasePtrs = std::vector<TensorBasePtr>;
317316
318317// Represent variadic input or output
319318struct Variadic : public Arg {
320- Variadic (const OrtW::CustomOpApi& api ,
319+ Variadic (const OrtW::CustomOpApi& custom_op_api ,
321320 OrtKernelContext& ctx,
322321 size_t indice,
323- bool is_input) : api_(api ), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(api , ctx, indice, is_input)) {
322+ bool is_input) : api_(custom_op_api ), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(custom_op_api , ctx, indice, is_input)) {
324323#if ORT_API_VERSION < 14
325324 ORTX_CXX_API_THROW (" Variadic input or output only supported after onnxruntime 1.14" , ORT_RUNTIME_EXCEPTION);
326325#endif
327326 if (is_input) {
328- auto input_count = api .KernelContext_GetInputCount (&ctx_);
327+ auto input_count = api_ .KernelContext_GetInputCount (&ctx_);
329328 for (size_t ith_input = 0 ; ith_input < input_count; ++ith_input) {
330329 auto * const_value = api_.KernelContext_GetInput (&ctx_, ith_input);
331330 auto * info = api_.GetTensorTypeAndShape (const_value);
@@ -334,40 +333,40 @@ struct Variadic : public Arg {
334333 TensorBasePtr tensor;
335334 switch (type) {
336335 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
337- tensor = std::make_unique<Custom::OrtTensor<bool >>(api , ctx, ith_input, true );
336+ tensor = std::make_unique<Custom::OrtTensor<bool >>(api_ , ctx, ith_input, true );
338337 break ;
339338 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
340- tensor = std::make_unique<Custom::OrtTensor<float >>(api , ctx, ith_input, true );
339+ tensor = std::make_unique<Custom::OrtTensor<float >>(api_ , ctx, ith_input, true );
341340 break ;
342341 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
343- tensor = std::make_unique<Custom::OrtTensor<double >>(api , ctx, ith_input, true );
342+ tensor = std::make_unique<Custom::OrtTensor<double >>(api_ , ctx, ith_input, true );
344343 break ;
345344 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
346- tensor = std::make_unique<Custom::OrtTensor<uint8_t >>(api , ctx, ith_input, true );
345+ tensor = std::make_unique<Custom::OrtTensor<uint8_t >>(api_ , ctx, ith_input, true );
347346 break ;
348347 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
349- tensor = std::make_unique<Custom::OrtTensor<int8_t >>(api , ctx, ith_input, true );
348+ tensor = std::make_unique<Custom::OrtTensor<int8_t >>(api_ , ctx, ith_input, true );
350349 break ;
351350 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
352- tensor = std::make_unique<Custom::OrtTensor<uint16_t >>(api , ctx, ith_input, true );
351+ tensor = std::make_unique<Custom::OrtTensor<uint16_t >>(api_ , ctx, ith_input, true );
353352 break ;
354353 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
355- tensor = std::make_unique<Custom::OrtTensor<int16_t >>(api , ctx, ith_input, true );
354+ tensor = std::make_unique<Custom::OrtTensor<int16_t >>(api_ , ctx, ith_input, true );
356355 break ;
357356 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
358- tensor = std::make_unique<Custom::OrtTensor<uint32_t >>(api , ctx, ith_input, true );
357+ tensor = std::make_unique<Custom::OrtTensor<uint32_t >>(api_ , ctx, ith_input, true );
359358 break ;
360359 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
361- tensor = std::make_unique<Custom::OrtTensor<int32_t >>(api , ctx, ith_input, true );
360+ tensor = std::make_unique<Custom::OrtTensor<int32_t >>(api_ , ctx, ith_input, true );
362361 break ;
363362 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
364- tensor = std::make_unique<Custom::OrtTensor<uint64_t >>(api , ctx, ith_input, true );
363+ tensor = std::make_unique<Custom::OrtTensor<uint64_t >>(api_ , ctx, ith_input, true );
365364 break ;
366365 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
367- tensor = std::make_unique<Custom::OrtTensor<int64_t >>(api , ctx, ith_input, true );
366+ tensor = std::make_unique<Custom::OrtTensor<int64_t >>(api_ , ctx, ith_input, true );
368367 break ;
369368 case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
370- tensor = std::make_unique<Custom::OrtTensor<std::string>>(api , ctx, ith_input, true );
369+ tensor = std::make_unique<Custom::OrtTensor<std::string>>(api_ , ctx, ith_input, true );
371370 break ;
372371 default :
373372 ORTX_CXX_API_THROW (" unknow input type" , ORT_RUNTIME_EXCEPTION);
@@ -395,7 +394,7 @@ struct Variadic : public Arg {
395394 size_t Size () const {
396395 return tensors_.size ();
397396 }
398-
397+
399398 const TensorBasePtr& operator [](size_t indice) const {
400399 return tensors_.at (indice);
401400 }
@@ -412,11 +411,11 @@ struct Variadic : public Arg {
412411
413412class OrtGraphKernelContext : public KernelContext {
414413 public:
415- OrtGraphKernelContext (const OrtApi& api , const OrtKernelContext& ctx) : api_(api ) {
414+ OrtGraphKernelContext (const OrtApi& ort_api , const OrtKernelContext& ctx) : api_(ort_api ) {
416415 OrtMemoryInfo* info;
417- OrtW::ThrowOnError (api, api .CreateCpuMemoryInfo (OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
418- OrtW::ThrowOnError (api, api .KernelContext_GetAllocator (&ctx, info, &allocator_));
419- api .ReleaseMemoryInfo (info);
416+ OrtW::ThrowOnError (api_, api_ .CreateCpuMemoryInfo (OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
417+ OrtW::ThrowOnError (api_, api_ .KernelContext_GetAllocator (&ctx, info, &allocator_));
418+ api_ .ReleaseMemoryInfo (info);
420419 }
421420
422421 virtual ~OrtGraphKernelContext () {
@@ -458,31 +457,31 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext {
458457 public:
459458 static const int cuda_resource_ver = 1 ;
460459
461- OrtGraphCudaKernelContext (const OrtApi& api , const OrtKernelContext& ctx) : api_(api ) {
462- api .KernelContext_GetResource (&ctx, cuda_resource_ver, CudaResource::cuda_handle_t , &cuda_stream_);
460+ OrtGraphCudaKernelContext (const OrtApi& ort_api , const OrtKernelContext& ctx) : api_(ort_api ) {
461+ api_ .KernelContext_GetResource (&ctx, cuda_resource_ver, CudaResource::cuda_handle_t , &cuda_stream_);
463462 if (!cuda_stream_) {
464463 ORTX_CXX_API_THROW (" Failed to fetch cuda stream from context" , ORT_RUNTIME_EXCEPTION);
465464 }
466- api .KernelContext_GetResource (&ctx, cuda_resource_ver, CudaResource::cublas_handle_t , &cublas_);
465+ api_ .KernelContext_GetResource (&ctx, cuda_resource_ver, CudaResource::cublas_handle_t , &cublas_);
467466 if (!cublas_) {
468467 ORTX_CXX_API_THROW (" Failed to fetch cublas handle from context" , ORT_RUNTIME_EXCEPTION);
469468 }
470469 void * resource = nullptr ;
471- OrtStatusPtr result = api .KernelContext_GetResource (&ctx, cuda_resource_ver, CudaResource::device_id_t , &resource);
470+ OrtStatusPtr result = api_ .KernelContext_GetResource (&ctx, cuda_resource_ver, CudaResource::device_id_t , &resource);
472471 if (result) {
473472 ORTX_CXX_API_THROW (" Failed to fetch device id from context" , ORT_RUNTIME_EXCEPTION);
474473 }
475474 memcpy (&device_id_, &resource, sizeof (int ));
476475
477476 OrtMemoryInfo* info;
478- OrtW::ThrowOnError (api, api .CreateCpuMemoryInfo (OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
479- OrtW::ThrowOnError (api, api .KernelContext_GetAllocator (&ctx, info, &cpu_allocator_));
480- api .ReleaseMemoryInfo (info);
477+ OrtW::ThrowOnError (api_, api_ .CreateCpuMemoryInfo (OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
478+ OrtW::ThrowOnError (api_, api_ .KernelContext_GetAllocator (&ctx, info, &cpu_allocator_));
479+ api_ .ReleaseMemoryInfo (info);
481480
482481 OrtMemoryInfo* cuda_mem_info;
483- OrtW::ThrowOnError (api, api .CreateMemoryInfo (" Cuda" , OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
484- OrtW::ThrowOnError (api, api .KernelContext_GetAllocator (&ctx, cuda_mem_info, &cuda_allocator_));
485- api .ReleaseMemoryInfo (cuda_mem_info);
482+ OrtW::ThrowOnError (api_, api_ .CreateMemoryInfo (" Cuda" , OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
483+ OrtW::ThrowOnError (api_, api_ .KernelContext_GetAllocator (&ctx, cuda_mem_info, &cuda_allocator_));
484+ api_ .ReleaseMemoryInfo (cuda_mem_info);
486485 }
487486
488487 virtual ~OrtGraphCudaKernelContext () {
@@ -944,7 +943,7 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
944943
945944class OrtAttributeReader {
946945 public:
947- OrtAttributeReader (const OrtApi& api , const OrtKernelInfo& info) : base_kernel_(api , info) {
946+ OrtAttributeReader (const OrtApi& ort_api , const OrtKernelInfo& info) : base_kernel_(ort_api , info) {
948947 }
949948
950949 template <class T >
0 commit comments