Skip to content

Commit f443a33

Browse files
committed
update cuda/pinned allocator to make compiler happy
1 parent 30e0f91 commit f443a33

File tree

2 files changed

+3
-12
lines changed

2 files changed

+3
-12
lines changed

plugin_execution_providers/tensorrt/cuda_allocator.h

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
#define ORT_API_MANUAL_INIT
88
#include "onnxruntime_cxx_api.h"
99

10-
constexpr const char* CUDA_ALLOCATOR = "Cuda";
11-
constexpr const char* CUDA_PINNED_ALLOCATOR = "CudaPinned";
12-
1310
using DeviceId = int16_t;
1411

1512
struct CUDAAllocator : OrtAllocator {
@@ -41,17 +38,11 @@ struct CUDAAllocator : OrtAllocator {
4138
};
4239

4340
struct CUDAPinnedAllocator : OrtAllocator {
44-
CUDAPinnedAllocator(const char* name = CUDA_PINNED_ALLOCATOR) {
41+
CUDAPinnedAllocator(const OrtMemoryInfo* mem_info) : mem_info_(mem_info) {
4542
OrtAllocator::version = ORT_API_VERSION;
4643
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<CUDAPinnedAllocator*>(this_)->Alloc(size); };
4744
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<CUDAPinnedAllocator*>(this_)->Free(p); };
4845
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const CUDAPinnedAllocator*>(this_)->Info(); };
49-
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
50-
api->CreateMemoryInfo(name,
51-
OrtAllocatorType::OrtDeviceAllocator,
52-
0 /* CPU device always with id 0 */,
53-
OrtMemType::OrtMemTypeDefault,
54-
&mem_info_);
5546
}
5647
// TODO: Handle destructor
5748
//~CUDAPinnedAllocator();
@@ -67,5 +58,5 @@ struct CUDAPinnedAllocator : OrtAllocator {
6758
CUDAPinnedAllocator& operator=(const CUDAPinnedAllocator&) = delete;
6859

6960
DeviceId device_id_ = 0;
70-
OrtMemoryInfo* mem_info_ = nullptr;
61+
const OrtMemoryInfo* mem_info_ = nullptr;
7162
};

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(
198198

199199
if (factory.GetDeviceIdForDefaultGpuMemInfo(memory_info, &device_id)) {
200200
// create a CUDA allocator
201-
auto cuda_allocator = std::make_unique<CUDAAllocator>(memory_info, static_cast<uint16_t>(device_id));
201+
auto cuda_allocator = std::make_unique<CUDAAllocator>(memory_info, static_cast<DeviceId>(device_id));
202202
*allocator = cuda_allocator.release();
203203
} else if (factory.GetDeviceIdForHostAccessibleMemInfo(memory_info, &device_id)) {
204204
// create a CUDA PINNED allocator

0 commit comments

Comments
 (0)