@@ -28,32 +28,6 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e
2828 ReleaseAllocator = ReleaseAllocatorImpl;
2929
3030 CreateDataTransfer = CreateDataTransferImpl;
31-
32- // Default GPU allocator OrtMemoryInfo
33- OrtMemoryInfo* mem_info = nullptr ;
34- auto * status = ort_api.CreateMemoryInfo_V2 (" Cuda" , OrtMemoryInfoDeviceType_GPU,
35- /* vendor*/ 0x10DE , /* device_id */ 0 , OrtDeviceMemoryType_DEFAULT,
36- /* alignment*/ 0 , OrtAllocatorType::OrtDeviceAllocator, &mem_info);
37- assert (status == nullptr ); // should never fail.
38- default_gpu_memory_info_ = MemoryInfoUniquePtr (mem_info, ort_api.ReleaseMemoryInfo );
39-
40- // CUDA PINNED allocator OrtMemoryInfo
41- // HOST_ACCESSIBLE memory should use the non-CPU device type
42- mem_info = nullptr ;
43- status = ort_api.CreateMemoryInfo_V2 (" CudaPinned" , OrtMemoryInfoDeviceType_GPU,
44- /* vendor*/ 0x10DE , /* device_id */ 0 , OrtDeviceMemoryType_HOST_ACCESSIBLE,
45- /* alignment*/ 0 , OrtAllocatorType::OrtDeviceAllocator, &mem_info);
46- assert (status == nullptr ); // should never fail.
47- host_accessible_gpu_memory_info_ = MemoryInfoUniquePtr (mem_info, ort_api.ReleaseMemoryInfo );
48-
49- // Create gpu data transfer
50- data_transfer_impl_ = std::make_unique<TRTEpDataTransfer>(
51- apis,
52- ep_api.MemoryInfo_GetMemoryDevice (default_gpu_memory_info_.get ()), // device memory
53- ep_api.MemoryInfo_GetMemoryDevice (host_accessible_gpu_memory_info_.get ()) // shared memory
54- );
55-
56- data_transfer_impl_.reset (); // but we're CPU only so we return nullptr for the IDataTransfer.
5731}
5832
5933const char * ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl (const OrtEpFactory* this_ptr) noexcept {
@@ -76,6 +50,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
7650 size_t & num_ep_devices = *p_num_ep_devices;
7751 auto * factory = static_cast <TensorrtExecutionProviderFactory*>(this_ptr);
7852
53+ std::vector<const OrtMemoryDevice*> cuda_gpu_mem_devices;
54+ std::vector<const OrtMemoryDevice*> cuda_pinned_mem_devices;
55+
7956 for (size_t i = 0 ; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
8057 // C API
8158 const OrtHardwareDevice& device = *devices[i];
@@ -88,7 +65,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
8865
8966 // The ep options can be provided here as default values.
9067 // Users can also call SessionOptionsAppendExecutionProvider_V2 C API with provided ep options to override.
91- factory->ort_api .AddKeyValuePair (ep_metadata, " version " , " 0.1 " ); // random example using made up values
68+ factory->ort_api .AddKeyValuePair (ep_metadata, " gpu_type " , " data center " ); // random example using made up values
9269 factory->ort_api .AddKeyValuePair (ep_options, " trt_builder_optimization_level" , " 3" );
9370
9471 // OrtEpDevice copies ep_metadata and ep_options.
@@ -103,25 +80,60 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
10380 return status;
10481 }
10582
106- // register the allocator info required by the EP.
107- RETURN_IF_ERROR (factory->ep_api .EpDevice_AddAllocatorInfo (ep_device, factory->default_gpu_memory_info_ .get ()));
108- RETURN_IF_ERROR (factory->ep_api .EpDevice_AddAllocatorInfo (ep_device, factory->host_accessible_gpu_memory_info_ .get ()));
83+ uint32_t vendor_id = factory->ort_api .HardwareDevice_VendorId (&device);
84+ uint32_t device_id = factory->ort_api .HardwareDevice_DeviceId (&device);
85+
86+ // CUDA allocator OrtMemoryInfo
87+ OrtMemoryInfo* mem_info = nullptr ;
88+ status = factory->ort_api .CreateMemoryInfo_V2 (" Cuda" , OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_DEFAULT,
89+ /* alignment*/ 0 , OrtAllocatorType::OrtDeviceAllocator, &mem_info);
90+
91+ assert (status == nullptr ); // should never fail.
92+ MemoryInfoUniquePtr cuda_gpu_memory_info = MemoryInfoUniquePtr (mem_info, factory->ort_api .ReleaseMemoryInfo );
93+
94+ // CUDA PINNED allocator OrtMemoryInfo
95+ // HOST_ACCESSIBLE memory should use the non-CPU device type.
96+ mem_info = nullptr ;
97+ status = factory->ort_api .CreateMemoryInfo_V2 (" CudaPinned" , OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE,
98+ /* alignment*/ 0 , OrtAllocatorType::OrtDeviceAllocator, &mem_info);
99+
100+ assert (status == nullptr ); // should never fail.
101+ MemoryInfoUniquePtr cuda_pinned_memory_info = MemoryInfoUniquePtr (mem_info, factory->ort_api .ReleaseMemoryInfo );
102+
103+ // Register the allocator info required by TRT EP.
104+ RETURN_IF_ERROR (factory->ep_api .EpDevice_AddAllocatorInfo (ep_device, cuda_gpu_memory_info.get ()));
105+ RETURN_IF_ERROR (factory->ep_api .EpDevice_AddAllocatorInfo (ep_device, cuda_pinned_memory_info.get ()));
106+
107+ // Get memory device from memory info for gpu data transfer
108+ cuda_gpu_mem_devices.push_back (factory->ep_api .MemoryInfo_GetMemoryDevice (cuda_gpu_memory_info.get ()));
109+ cuda_pinned_mem_devices.push_back (factory->ep_api .MemoryInfo_GetMemoryDevice (cuda_pinned_memory_info.get ()));
110+
111+ factory->SetDefaultGpuMemInfo (std::move (cuda_gpu_memory_info), device_id);
112+ factory->SetHostAccessibleMemInfo (std::move (cuda_pinned_memory_info), device_id);
109113
110114 ep_devices[num_ep_devices++] = ep_device;
111115 }
112116
113- // C++ API equivalent. Throws on error.
114- // {
115- // Ort::ConstHardwareDevice device(devices[i]);
116- // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
117- // Ort::KeyValuePairs ep_metadata;
118- // Ort::KeyValuePairs ep_options;
119- // ep_metadata.Add("version", "0.1");
120- // ep_options.Add("trt_builder_optimization_level", "3");
121- // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()};
122- // ep_devices[num_ep_devices++] = ep_device.release();
123- // }
124- // }
117+ // Create gpu data transfer
118+ auto data_transfer_impl = std::make_unique<TRTEpDataTransfer>(
119+ static_cast <const ApiPtrs&>(*factory),
120+ cuda_gpu_mem_devices, // device memory
121+ cuda_pinned_mem_devices // shared memory
122+ );
123+
124+
125+ // C++ API equivalent. Throws on error.
126+ // {
127+ // Ort::ConstHardwareDevice device(devices[i]);
128+ // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
129+ // Ort::KeyValuePairs ep_metadata;
130+ // Ort::KeyValuePairs ep_options;
131+ // ep_metadata.Add("version", "0.1");
132+ // ep_options.Add("trt_builder_optimization_level", "3");
133+ // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()};
134+ // ep_devices[num_ep_devices++] = ep_device.release();
135+ // }
136+ // }
125137 }
126138
127139 return nullptr ;
@@ -181,11 +193,14 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(
181193
182194 // NOTE: The OrtMemoryInfo pointer should only ever be coming straight from an OrtEpDevice, and pointer based
183195 // matching should work.
184- if (memory_info == factory.default_gpu_memory_info_ .get ()) {
196+
197+ uint32_t device_id = 0 ;
198+
199+ if (factory.GetDeviceIdForDefaultGpuMemInfo (memory_info, &device_id)) {
185200 // create a CUDA allocator
186- auto cuda_allocator = std::make_unique<CUDAAllocator>(memory_info);
201+ auto cuda_allocator = std::make_unique<CUDAAllocator>(memory_info, static_cast < uint16_t >(device_id) );
187202 *allocator = cuda_allocator.release ();
188- } else if (memory_info == factory.host_accessible_gpu_memory_info_ . get ( )) {
203+ } else if (factory.GetDeviceIdForHostAccessibleMemInfo (memory_info, &device_id )) {
189204 // create a CUDA PINNED allocator
190205 auto cuda_pinned_allocator = std::make_unique<CUDAPinnedAllocator>(memory_info);
191206 *allocator = cuda_pinned_allocator.release ();
@@ -212,8 +227,50 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl
212227 return nullptr ;
213228}
214229
215- OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultMemInfo () const {
216- return default_gpu_memory_info_.get ();
230+ bool TensorrtExecutionProviderFactory::GetDeviceIdForDefaultGpuMemInfo (const OrtMemoryInfo* mem_info, uint32_t * device_id) const {
231+ auto iter = cuda_gpu_memory_info_to_device_id_map_.find (mem_info);
232+ if (iter != cuda_gpu_memory_info_to_device_id_map_.end ()) {
233+ *device_id = iter->second ;
234+ return true ;
235+ }
236+ return false ;
237+ }
238+
239+ const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultGpuMemInfoForDeviceId (uint32_t device_id) const {
240+ auto iter = device_id_to_cuda_gpu_memory_info_map_.find (device_id);
241+ if (iter != device_id_to_cuda_gpu_memory_info_map_.end ()) {
242+ return iter->second ;
243+ }
244+ return nullptr ;
245+ }
246+
247+ void TensorrtExecutionProviderFactory::SetDefaultGpuMemInfo (MemoryInfoUniquePtr mem_info, uint32_t device_id) {
248+ cuda_gpu_memory_info_to_device_id_map_[mem_info.get ()] = device_id;
249+ device_id_to_cuda_gpu_memory_info_map_[device_id] = mem_info.get ();
250+ cuda_gpu_memory_infos_.push_back (std::move (mem_info));
251+ }
252+
253+ bool TensorrtExecutionProviderFactory::GetDeviceIdForHostAccessibleMemInfo (const OrtMemoryInfo* mem_info, uint32_t * device_id) const {
254+ auto iter = cuda_pinned_memory_info_to_device_id_map_.find (mem_info);
255+ if (iter != cuda_pinned_memory_info_to_device_id_map_.end ()) {
256+ *device_id = iter->second ;
257+ return true ;
258+ }
259+ return false ;
260+ }
261+
262+ const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetHostAccessibleMemInfoForDeviceId (uint32_t device_id) const {
263+ auto iter = device_id_to_cuda_pinned_memory_info_map_.find (device_id);
264+ if (iter != device_id_to_cuda_pinned_memory_info_map_.end ()) {
265+ return iter->second ;
266+ }
267+ return nullptr ;
268+ }
269+
270+ void TensorrtExecutionProviderFactory::SetHostAccessibleMemInfo (MemoryInfoUniquePtr mem_info, uint32_t device_id) {
271+ cuda_pinned_memory_info_to_device_id_map_[mem_info.get ()] = device_id;
272+ device_id_to_cuda_pinned_memory_info_map_[device_id] = mem_info.get ();
273+ cuda_pinned_memory_infos_.push_back (std::move (mem_info));
217274}
218275
219276// To make symbols visible on macOS/iOS
0 commit comments