@@ -44,6 +44,32 @@ const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVersionImpl(const
   return factory->ep_version_.c_str();
 }
 
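+// Creates and caches an OrtMemoryInfo pair (device-default and host-accessible/pinned)
+// for each CUDA device; the factory owns them for the lifetime of the EP.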
+OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_devices) {
+  cuda_gpu_memory_infos.reserve(num_devices);
+  cuda_pinned_memory_infos.reserve(num_devices);
+
+  for (int device_id = 0; device_id < num_devices; ++device_id) {
+    OrtMemoryInfo* mem_info = nullptr;
+    RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU,
+                                                /*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE,
+                                                /*device_id*/ device_id, OrtDeviceMemoryType_DEFAULT,
+                                                /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info));
+
+    cuda_gpu_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo));
+
+    // HOST_ACCESSIBLE memory should use the non-CPU device type.
+    mem_info = nullptr;
+    RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU,
+                                                /*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE,
+                                                /*device_id*/ device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE,
+                                                /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info));
+
+    cuda_pinned_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo));
+  }
+
+  return nullptr;
+}
+
 OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImpl(
     OrtEpFactory* this_ptr,
     const OrtHardwareDevice* const* devices,
@@ -54,18 +80,24 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
   size_t& num_ep_devices = *p_num_ep_devices;
   auto* factory = static_cast<TensorrtExecutionProviderFactory*>(this_ptr);
 
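+  // Query the CUDA device count up front so a memory info pair can be created and cached per device.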
+  int num_cuda_devices = 0;
+  cudaGetDeviceCount(&num_cuda_devices);
+  RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices));
+
   std::vector<const OrtMemoryDevice*> cuda_gpu_mem_devices;
   std::vector<const OrtMemoryDevice*> cuda_pinned_mem_devices;
-  int GPU_cnt = 0;
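+  // Indexes into the cached memory infos; also doubles as the duplicate-GPU guard below.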
+  int32_t device_id = 0;
 
   for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
     // C API
     const OrtHardwareDevice& device = *devices[i];
     if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
-      if (GPU_cnt > 0) {
+
+      // Workaround for duplicate devices when using remote desktop.
+      if (device_id > 0) {
         continue;
       }
-      GPU_cnt++;
+
       // These can be returned as nullptr if you have nothing to add.
       OrtKeyValuePairs* ep_metadata = nullptr;
       OrtKeyValuePairs* ep_options = nullptr;
@@ -89,39 +121,19 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
         return status;
       }
 
-      uint32_t vendor_id = factory->ort_api.HardwareDevice_VendorId(&device);
-      // uint32_t device_id = factory->ort_api.HardwareDevice_DeviceId(&device);
-      uint32_t device_id = 0;
-
-      // CUDA allocator OrtMemoryInfo
-      OrtMemoryInfo* mem_info = nullptr;
-      status = factory->ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_DEFAULT,
-                                                    /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info);
-
-      assert(status == nullptr);  // should never fail.
-      MemoryInfoUniquePtr cuda_gpu_memory_info = MemoryInfoUniquePtr(mem_info, factory->ort_api.ReleaseMemoryInfo);
-
-      // CUDA PINNED allocator OrtMemoryInfo
-      // HOST_ACCESSIBLE memory should use the non-CPU device type.
-      mem_info = nullptr;
-      status = factory->ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE,
-                                                    /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info);
-
-      assert(status == nullptr);  // should never fail.
-      MemoryInfoUniquePtr cuda_pinned_memory_info = MemoryInfoUniquePtr(mem_info, factory->ort_api.ReleaseMemoryInfo);
+      const OrtMemoryInfo* cuda_gpu_mem_info = factory->cuda_gpu_memory_infos[device_id].get();
+      const OrtMemoryInfo* cuda_pinned_mem_info = factory->cuda_pinned_memory_infos[device_id].get();
 
       // Register the allocator info required by TRT EP.
-      RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_gpu_memory_info.get()));
-      RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_memory_info.get()));
+      RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_gpu_mem_info));
+      RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_mem_info));
 
       // Get memory device from memory info for GPU data transfer.
-      cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_memory_info.get()));
-      cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_memory_info.get()));
-
-      factory->SetDefaultGpuMemInfo(std::move(cuda_gpu_memory_info), device_id);
-      factory->SetHostAccessibleMemInfo(std::move(cuda_pinned_memory_info), device_id);
+      cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_mem_info));
+      cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_mem_info));
 
       ep_devices[num_ep_devices++] = ep_device;
+      ++device_id;
     }
 
     // C++ API equivalent. Throws on error.
@@ -202,13 +214,15 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(
   // NOTE: The OrtMemoryInfo pointer should only ever come straight from an OrtEpDevice, and pointer-based
   // matching should work.
 
-  uint32_t device_id = 0;
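+  // The device id and memory type are read directly from the OrtMemoryDevice, replacing the factory's lookup maps.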
+  const OrtMemoryDevice* mem_device = factory.ep_api.MemoryInfo_GetMemoryDevice(memory_info);
+  uint32_t device_id = factory.ep_api.MemoryDevice_GetDeviceId(mem_device);
 
-  if (factory.GetDeviceIdForDefaultGpuMemInfo(memory_info, &device_id)) {
+  if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_DEFAULT) {
     // create a CUDA allocator
     auto cuda_allocator = std::make_unique<CUDAAllocator>(memory_info, static_cast<DeviceId>(device_id));
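+    // Remember the default GPU memory info for this device id so it can be looked up later.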
+    factory.device_id_to_cuda_gpu_memory_info_map[device_id] = memory_info;
     *allocator = cuda_allocator.release();
-  } else if (factory.GetDeviceIdForHostAccessibleMemInfo(memory_info, &device_id)) {
+  } else if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) {
     // create a CUDA PINNED allocator
     auto cuda_pinned_allocator = std::make_unique<CUDAPinnedAllocator>(memory_info);
     *allocator = cuda_pinned_allocator.release();
@@ -235,52 +249,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl
   return nullptr;
 }
 
-bool TensorrtExecutionProviderFactory::GetDeviceIdForDefaultGpuMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const {
-  auto iter = cuda_gpu_memory_info_to_device_id_map_.find(mem_info);
-  if (iter != cuda_gpu_memory_info_to_device_id_map_.end()) {
-    *device_id = iter->second;
-    return true;
-  }
-  return false;
-}
-
-const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultGpuMemInfoForDeviceId(uint32_t device_id) const {
-  auto iter = device_id_to_cuda_gpu_memory_info_map_.find(device_id);
-  if (iter != device_id_to_cuda_gpu_memory_info_map_.end()) {
-    return iter->second;
-  }
-  return nullptr;
-}
-
-void TensorrtExecutionProviderFactory::SetDefaultGpuMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id) {
-  cuda_gpu_memory_info_to_device_id_map_[mem_info.get()] = device_id;
-  device_id_to_cuda_gpu_memory_info_map_[device_id] = mem_info.get();
-  cuda_gpu_memory_infos_.push_back(std::move(mem_info));
-}
-
-bool TensorrtExecutionProviderFactory::GetDeviceIdForHostAccessibleMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const {
-  auto iter = cuda_pinned_memory_info_to_device_id_map_.find(mem_info);
-  if (iter != cuda_pinned_memory_info_to_device_id_map_.end()) {
-    *device_id = iter->second;
-    return true;
-  }
-  return false;
-}
-
-const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const {
-  auto iter = device_id_to_cuda_pinned_memory_info_map_.find(device_id);
-  if (iter != device_id_to_cuda_pinned_memory_info_map_.end()) {
-    return iter->second;
-  }
-  return nullptr;
-}
-
-void TensorrtExecutionProviderFactory::SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id) {
-  cuda_pinned_memory_info_to_device_id_map_[mem_info.get()] = device_id;
-  device_id_to_cuda_pinned_memory_info_map_[device_id] = mem_info.get();
-  cuda_pinned_memory_infos_.push_back(std::move(mem_info));
-}
-
 void TensorrtExecutionProviderFactory::SetGPUDataTransfer(std::unique_ptr<TRTEpDataTransfer> gpu_data_transfer) {
   data_transfer_impl_ = std::move(gpu_data_transfer);
 }