@@ -27,6 +27,8 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e
2727 ReleaseAllocator = ReleaseAllocatorImpl;
2828
2929 CreateDataTransfer = CreateDataTransferImpl;
30+
31+ IsStreamAware = IsStreamAwareImpl;
3032}
3133
3234const char * ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl (const OrtEpFactory* this_ptr) noexcept {
@@ -80,24 +82,19 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
8082 size_t & num_ep_devices = *p_num_ep_devices;
8183 auto * factory = static_cast <TensorrtExecutionProviderFactory*>(this_ptr);
8284
85+ // Create two memory infos per CUDA device: one for GPU device memory and one for pinned memory.
86+ // These memory infos are required to create the allocators and the GPU data transfer.
8387 int num_cuda_devices = 0 ;
8488 cudaGetDeviceCount (&num_cuda_devices);
8589 RETURN_IF_ERROR (factory->CreateMemoryInfoForDevices (num_cuda_devices));
8690
87- std::vector<const OrtMemoryDevice*> cuda_gpu_mem_devices;
88- std::vector<const OrtMemoryDevice*> cuda_pinned_mem_devices;
8991 int32_t device_id = 0 ;
9092
9193 for (size_t i = 0 ; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
9294 // C API
9395 const OrtHardwareDevice& device = *devices[i];
94- if (factory->ort_api .HardwareDevice_Type (&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
95-
96- // workaround for duplicate devices when using remote desktop.
97- if (device_id > 0 ) {
98- continue ;
99- }
10096
97+ if (factory->ort_api .HardwareDevice_Type (&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
10198 // These can be returned as nullptr if you have nothing to add.
10299 OrtKeyValuePairs* ep_metadata = nullptr ;
103100 OrtKeyValuePairs* ep_options = nullptr ;
@@ -129,8 +126,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
129126 RETURN_IF_ERROR (factory->ep_api .EpDevice_AddAllocatorInfo (ep_device, cuda_pinned_mem_info));
130127
131128 // Get memory device from memory info for gpu data transfer
132- cuda_gpu_mem_devices.push_back (factory->ep_api .MemoryInfo_GetMemoryDevice (cuda_gpu_mem_info));
133- cuda_pinned_mem_devices.push_back (factory->ep_api .MemoryInfo_GetMemoryDevice (cuda_pinned_mem_info));
129+ factory-> cuda_gpu_mem_devices .push_back (factory->ep_api .MemoryInfo_GetMemoryDevice (cuda_gpu_mem_info));
130+ factory-> cuda_pinned_mem_devices .push_back (factory->ep_api .MemoryInfo_GetMemoryDevice (cuda_pinned_mem_info));
134131
135132 ep_devices[num_ep_devices++] = ep_device;
136133 ++device_id;
@@ -152,10 +149,12 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
152149
153150 // Create gpu data transfer
154151 auto data_transfer_impl = std::make_unique<TRTEpDataTransfer>(static_cast <const ApiPtrs&>(*factory),
155- cuda_gpu_mem_devices, // device memory
156- cuda_pinned_mem_devices // shared memory
152+ factory-> cuda_gpu_mem_devices , // device memory
153+ factory-> cuda_pinned_mem_devices // shared memory
157154 );
158- factory->SetGPUDataTransfer (std::move (data_transfer_impl));
155+
156+ factory->data_transfer_impl = std::move (data_transfer_impl);
157+
159158 return nullptr ;
160159}
161160
@@ -244,13 +243,13 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl
244243 OrtEpFactory* this_ptr,
245244 OrtDataTransferImpl** data_transfer) noexcept {
246245 auto & factory = *static_cast <TensorrtExecutionProviderFactory*>(this_ptr);
247- *data_transfer = factory.data_transfer_impl_ .get ();
246+ *data_transfer = factory.data_transfer_impl .get ();
248247
249248 return nullptr ;
250249}
251250
252- void TensorrtExecutionProviderFactory::SetGPUDataTransfer (std::unique_ptr<TRTEpDataTransfer> gpu_data_transfer) {
253- data_transfer_impl_ = std::move (gpu_data_transfer) ;
251+ bool ORT_API_CALL TensorrtExecutionProviderFactory::IsStreamAwareImpl ( const OrtEpFactory* /* this_ptr */ ) noexcept {
252+ return false ;
254253}
255254
256255// To make symbols visible on macOS/iOS
@@ -265,6 +264,7 @@ extern "C" {
265264// Public symbols
266265//
267266EXPORT_SYMBOL OrtStatus* CreateEpFactories (const char * registration_name, const OrtApiBase* ort_api_base,
267+ const OrtLogger*,
268268 OrtEpFactory** factories, size_t max_factories, size_t * num_factories) {
269269 const OrtApi* ort_api = ort_api_base->GetApi (ORT_API_VERSION);
270270 const OrtEpApi* ort_ep_api = ort_api->GetEpApi ();
@@ -285,7 +285,7 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const
285285}
286286
287287EXPORT_SYMBOL OrtStatus* ReleaseEpFactory (OrtEpFactory* factory) {
288- delete factory;
288+ delete static_cast <TensorrtExecutionProviderFactory*>( factory) ;
289289 return nullptr ;
290290}
291291
0 commit comments