Skip to content

Commit 6fd38c3

Browse files
committed
Refactor how OrtMemoryInfo instances are stored in the factory
1 parent da729f9 commit 6fd38c3

File tree

3 files changed

+61
-102
lines changed

3 files changed

+61
-102
lines changed

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2327,7 +2327,11 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void*
23272327

23282328
// Get default OrtMemoryInfo from factory
23292329
// Get allocator from OrtKernelContext
2330-
const OrtMemoryInfo* mem_info = ep.factory_.GetDefaultGpuMemInfoForDeviceId(device_id);
2330+
const OrtMemoryInfo* mem_info = nullptr;
2331+
if (ep.factory_.device_id_to_cuda_gpu_memory_info_map.find(device_id) !=
2332+
ep.factory_.device_id_to_cuda_gpu_memory_info_map.end()) {
2333+
mem_info = ep.factory_.device_id_to_cuda_gpu_memory_info_map[device_id];
2334+
}
23312335
OrtAllocator* alloc = nullptr;
23322336
ep.GetAllocator(&alloc);
23332337
if (alloc == nullptr) {

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 47 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,32 @@ const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVersionImpl(const
4444
return factory->ep_version_.c_str();
4545
}
4646

47+
OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_devices) {
48+
cuda_gpu_memory_infos.reserve(num_devices);
49+
cuda_pinned_memory_infos.reserve(num_devices);
50+
51+
for (int device_id = 0; device_id < num_devices; ++device_id) {
52+
OrtMemoryInfo* mem_info = nullptr;
53+
RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU,
54+
/*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE,
55+
/* device_id */ device_id, OrtDeviceMemoryType_DEFAULT,
56+
/*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info));
57+
58+
cuda_gpu_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo));
59+
60+
// HOST_ACCESSIBLE memory should use the non-CPU device type
61+
mem_info = nullptr;
62+
RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU,
63+
/*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE,
64+
/* device_id */ device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE,
65+
/*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info));
66+
67+
cuda_pinned_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo));
68+
}
69+
70+
return nullptr;
71+
}
72+
4773
OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImpl(
4874
OrtEpFactory* this_ptr,
4975
const OrtHardwareDevice* const* devices,
@@ -54,18 +80,24 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
5480
size_t& num_ep_devices = *p_num_ep_devices;
5581
auto* factory = static_cast<TensorrtExecutionProviderFactory*>(this_ptr);
5682

83+
int num_cuda_devices = 0;
84+
cudaGetDeviceCount(&num_cuda_devices);
85+
RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices));
86+
5787
std::vector<const OrtMemoryDevice*> cuda_gpu_mem_devices;
5888
std::vector<const OrtMemoryDevice*> cuda_pinned_mem_devices;
59-
int GPU_cnt = 0;
89+
int32_t device_id = 0;
6090

6191
for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
6292
// C API
6393
const OrtHardwareDevice& device = *devices[i];
6494
if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
65-
if (GPU_cnt > 0) {
95+
96+
// workaround for duplicate devices when using remote desktop.
97+
if (device_id > 0) {
6698
continue;
6799
}
68-
GPU_cnt++;
100+
69101
// These can be returned as nullptr if you have nothing to add.
70102
OrtKeyValuePairs* ep_metadata = nullptr;
71103
OrtKeyValuePairs* ep_options = nullptr;
@@ -89,39 +121,19 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
89121
return status;
90122
}
91123

92-
uint32_t vendor_id = factory->ort_api.HardwareDevice_VendorId(&device);
93-
//uint32_t device_id = factory->ort_api.HardwareDevice_DeviceId(&device);
94-
uint32_t device_id = 0;
95-
96-
// CUDA allocator OrtMemoryInfo
97-
OrtMemoryInfo* mem_info = nullptr;
98-
status = factory->ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_DEFAULT,
99-
/*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info);
100-
101-
assert(status == nullptr); // should never fail.
102-
MemoryInfoUniquePtr cuda_gpu_memory_info = MemoryInfoUniquePtr(mem_info, factory->ort_api.ReleaseMemoryInfo);
103-
104-
// CUDA PINNED allocator OrtMemoryInfo
105-
// HOST_ACCESSIBLE memory should use the non-CPU device type.
106-
mem_info = nullptr;
107-
status = factory->ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE,
108-
/*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info);
109-
110-
assert(status == nullptr); // should never fail.
111-
MemoryInfoUniquePtr cuda_pinned_memory_info = MemoryInfoUniquePtr(mem_info, factory->ort_api.ReleaseMemoryInfo);
124+
const OrtMemoryInfo* cuda_gpu_mem_info = factory->cuda_gpu_memory_infos[device_id].get();
125+
const OrtMemoryInfo* cuda_pinned_mem_info = factory->cuda_pinned_memory_infos[device_id].get();
112126

113127
// Register the allocator info required by TRT EP.
114-
RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_gpu_memory_info.get()));
115-
RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_memory_info.get()));
128+
RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_gpu_mem_info));
129+
RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_mem_info));
116130

117131
// Get memory device from memory info for gpu data transfer
118-
cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_memory_info.get()));
119-
cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_memory_info.get()));
120-
121-
factory->SetDefaultGpuMemInfo(std::move(cuda_gpu_memory_info), device_id);
122-
factory->SetHostAccessibleMemInfo(std::move(cuda_pinned_memory_info), device_id);
132+
cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_mem_info));
133+
cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_mem_info));
123134

124135
ep_devices[num_ep_devices++] = ep_device;
136+
++device_id;
125137
}
126138

127139
// C++ API equivalent. Throws on error.
@@ -202,13 +214,15 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(
202214
// NOTE: The OrtMemoryInfo pointer should only ever be coming straight from an OrtEpDevice, and pointer based
203215
// matching should work.
204216

205-
uint32_t device_id = 0;
217+
const OrtMemoryDevice* mem_device = factory.ep_api.MemoryInfo_GetMemoryDevice(memory_info);
218+
uint32_t device_id = factory.ep_api.MemoryDevice_GetDeviceId(mem_device);
206219

207-
if (factory.GetDeviceIdForDefaultGpuMemInfo(memory_info, &device_id)) {
220+
if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_DEFAULT) {
208221
// create a CUDA allocator
209222
auto cuda_allocator = std::make_unique<CUDAAllocator>(memory_info, static_cast<DeviceId>(device_id));
223+
factory.device_id_to_cuda_gpu_memory_info_map[device_id] = memory_info;
210224
*allocator = cuda_allocator.release();
211-
} else if (factory.GetDeviceIdForHostAccessibleMemInfo(memory_info, &device_id)) {
225+
} else if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) {
212226
// create a CUDA PINNED allocator
213227
auto cuda_pinned_allocator = std::make_unique<CUDAPinnedAllocator>(memory_info);
214228
*allocator = cuda_pinned_allocator.release();
@@ -235,52 +249,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl
235249
return nullptr;
236250
}
237251

238-
bool TensorrtExecutionProviderFactory::GetDeviceIdForDefaultGpuMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const {
239-
auto iter = cuda_gpu_memory_info_to_device_id_map_.find(mem_info);
240-
if (iter != cuda_gpu_memory_info_to_device_id_map_.end()) {
241-
*device_id = iter->second;
242-
return true;
243-
}
244-
return false;
245-
}
246-
247-
const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultGpuMemInfoForDeviceId(uint32_t device_id) const {
248-
auto iter = device_id_to_cuda_gpu_memory_info_map_.find(device_id);
249-
if (iter != device_id_to_cuda_gpu_memory_info_map_.end()) {
250-
return iter->second;
251-
}
252-
return nullptr;
253-
}
254-
255-
void TensorrtExecutionProviderFactory::SetDefaultGpuMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id) {
256-
cuda_gpu_memory_info_to_device_id_map_[mem_info.get()] = device_id;
257-
device_id_to_cuda_gpu_memory_info_map_[device_id] = mem_info.get();
258-
cuda_gpu_memory_infos_.push_back(std::move(mem_info));
259-
}
260-
261-
bool TensorrtExecutionProviderFactory::GetDeviceIdForHostAccessibleMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const {
262-
auto iter = cuda_pinned_memory_info_to_device_id_map_.find(mem_info);
263-
if (iter != cuda_pinned_memory_info_to_device_id_map_.end()) {
264-
*device_id = iter->second;
265-
return true;
266-
}
267-
return false;
268-
}
269-
270-
const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const {
271-
auto iter = device_id_to_cuda_pinned_memory_info_map_.find(device_id);
272-
if (iter != device_id_to_cuda_pinned_memory_info_map_.end()) {
273-
return iter->second;
274-
}
275-
return nullptr;
276-
}
277-
278-
void TensorrtExecutionProviderFactory::SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id) {
279-
cuda_pinned_memory_info_to_device_id_map_[mem_info.get()] = device_id;
280-
device_id_to_cuda_pinned_memory_info_map_[device_id] = mem_info.get();
281-
cuda_pinned_memory_infos_.push_back(std::move(mem_info));
282-
}
283-
284252
void TensorrtExecutionProviderFactory::SetGPUDataTransfer(std::unique_ptr<TRTEpDataTransfer> gpu_data_transfer) {
285253
data_transfer_impl_ = std::move(gpu_data_transfer);
286254
}

plugin_execution_providers/tensorrt/tensorrt_provider_factory.h

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
1616

1717
const OrtMemoryInfo* GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const;
1818

19+
OrtStatus* CreateMemoryInfoForDevices(int num_devices);
20+
21+
// CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo
22+
// instance required for that.
23+
// Current TRT EP implementation uses one default OrtMemoryInfo and one host accessible OrtMemoryInfo per ep device.
24+
std::vector<MemoryInfoUniquePtr> cuda_gpu_memory_infos;
25+
std::vector<MemoryInfoUniquePtr> cuda_pinned_memory_infos;
26+
std::unordered_map<uint32_t, const OrtMemoryInfo*> device_id_to_cuda_gpu_memory_info_map; // device id -> OrtMemoryInfo
27+
1928
private:
2029
static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept;
2130

@@ -44,33 +53,11 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
4453
static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr,
4554
OrtDataTransferImpl** data_transfer) noexcept;
4655

47-
bool GetDeviceIdForDefaultGpuMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const;
48-
49-
void SetDefaultGpuMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id);
50-
51-
bool GetDeviceIdForHostAccessibleMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const;
52-
53-
void SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id);
54-
5556
void SetGPUDataTransfer(std::unique_ptr<TRTEpDataTransfer> gpu_data_transfer);
5657

5758
const std::string ep_name_; // EP name
5859
const std::string vendor_{"Nvidia"}; // EP vendor name
5960
const std::string ep_version_{"0.1.0"}; // EP version
6061

61-
// OrtMemoryInfo for allocators and data transfer.
62-
63-
// CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo instance required for that.
64-
// Current TRT EP implementation uses one default OrtMemoryInfo and one host accessible OrtMemoryInfo per ep device.
65-
std::unordered_map<const OrtMemoryInfo*, uint32_t> cuda_gpu_memory_info_to_device_id_map_; // OrtMemoryInfo -> device id
66-
std::unordered_map<const OrtMemoryInfo*, uint32_t> cuda_pinned_memory_info_to_device_id_map_;
67-
std::unordered_map<uint32_t, const OrtMemoryInfo*> device_id_to_cuda_gpu_memory_info_map_; // device id -> OrtMemoryInfo
68-
std::unordered_map<uint32_t, const OrtMemoryInfo*> device_id_to_cuda_pinned_memory_info_map_;
69-
std::vector<MemoryInfoUniquePtr> cuda_gpu_memory_infos_;
70-
std::vector<MemoryInfoUniquePtr> cuda_pinned_memory_infos_;
71-
72-
// CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed.
73-
// MemoryInfoUniquePtr cpu_memory_info_;
74-
7562
std::unique_ptr<TRTEpDataTransfer> data_transfer_impl_; // data transfer implementation for this factory
7663
};

0 commit comments

Comments (0)