77#define ORT_API_MANUAL_INIT
88#include " onnxruntime_cxx_api.h"
99
10- constexpr const char * CUDA_ALLOCATOR = " Cuda" ;
11- constexpr const char * CUDA_PINNED_ALLOCATOR = " CudaPinned" ;
12-
1310using DeviceId = int16_t ;
1411
1512struct CUDAAllocator : OrtAllocator {
@@ -41,17 +38,11 @@ struct CUDAAllocator : OrtAllocator {
4138};
4239
4340struct CUDAPinnedAllocator : OrtAllocator {
44- CUDAPinnedAllocator (const char * name = CUDA_PINNED_ALLOCATOR ) {
41+ CUDAPinnedAllocator (const OrtMemoryInfo* mem_info) : mem_info_(mem_info ) {
4542 OrtAllocator::version = ORT_API_VERSION;
4643 OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast <CUDAPinnedAllocator*>(this_)->Alloc (size); };
4744 OrtAllocator::Free = [](OrtAllocator* this_, void * p) { static_cast <CUDAPinnedAllocator*>(this_)->Free (p); };
4845 OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast <const CUDAPinnedAllocator*>(this_)->Info (); };
49- const OrtApi* api = OrtGetApiBase ()->GetApi (ORT_API_VERSION);
50- api->CreateMemoryInfo (name,
51- OrtAllocatorType::OrtDeviceAllocator,
52- 0 /* CPU device always with id 0 */ ,
53- OrtMemType::OrtMemTypeDefault,
54- &mem_info_);
5546 }
5647 // TODO: Handle destructor
5748 // ~CUDAPinnedAllocator();
@@ -67,5 +58,5 @@ struct CUDAPinnedAllocator : OrtAllocator {
6758 CUDAPinnedAllocator& operator =(const CUDAPinnedAllocator&) = delete ;
6859
6960 DeviceId device_id_ = 0 ;
70- OrtMemoryInfo* mem_info_ = nullptr ;
61+ const OrtMemoryInfo* mem_info_ = nullptr ;
7162};
0 commit comments