Skip to content

Commit ccf20da

Browse files
committed
update and sync with latest ep c api
1 parent da0f9c6 commit ccf20da

File tree

4 files changed

+37
-35
lines changed

plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
void CUDA_RETURN_IF_ERROR(cudaError_t res);
1010

1111
/*static*/
12-
bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(void* this_ptr, const OrtMemoryDevice* src_memory_device,
12+
bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(const OrtDataTransferImpl* this_ptr,
13+
const OrtMemoryDevice* src_memory_device,
1314
const OrtMemoryDevice* dst_memory_device) noexcept {
14-
auto& impl = *static_cast<TRTEpDataTransfer*>(this_ptr);
15+
auto& impl = *static_cast<const TRTEpDataTransfer*>(this_ptr);
1516

1617
auto it = std::find_if(impl.cuda_gpu_mem_devices_.begin(), impl.cuda_gpu_mem_devices_.end(),
1718
[&impl, &src_memory_device, &dst_memory_device](const OrtMemoryDevice* memory_device) {
@@ -29,7 +30,7 @@ bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(void* this_ptr, const OrtMemory
2930
// function to copy one or more tensors.
3031
// implementation can optionally use async copy if a stream is available for the input.
3132
/*static*/
32-
OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(void* this_ptr,
33+
OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(OrtDataTransferImpl* this_ptr,
3334
const OrtValue** src_tensors_ptr,
3435
OrtValue** dst_tensors_ptr,
3536
OrtSyncStream** streams_ptr,
@@ -97,10 +98,10 @@ OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(void* this_ptr,
9798
}
9899

99100
/*static*/
100-
void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(void* this_ptr) noexcept {
101+
void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept {
101102
// In our setup the factory owns a shared ExampleDataTransfer instance so it will do the cleanup, and we ignore
102103
// the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h)
103104
//
104105
// If you create a new instance on each call to OrtEpFactory::CreateDataTransfer you call `delete` here
105-
delete static_cast<TRTEpDataTransfer*>(this_ptr);
106+
delete this_ptr;
106107
}

plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,28 @@
44
#pragma once
55

66
#include "ep_utils.h"
7+
#include "onnxruntime_c_api.h"
78

89
struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs {
9-
TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector<const OrtMemoryDevice*> device_mem_infos,
10-
std::vector<const OrtMemoryDevice*> shared_mem_infos)
10+
TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector<const OrtMemoryDevice*>& device_mem_infos,
11+
std::vector<const OrtMemoryDevice*>& shared_mem_infos)
1112
: ApiPtrs(api_ptrs), cuda_gpu_mem_devices_{device_mem_infos}, cuda_pinned_mem_devices_{shared_mem_infos} {
1213
CanCopy = CanCopyImpl;
1314
CopyTensors = CopyTensorsImpl;
1415
Release = ReleaseImpl;
1516
}
1617

17-
static bool ORT_API_CALL CanCopyImpl(void* this_ptr, const OrtMemoryDevice* src_memory_device,
18+
static bool ORT_API_CALL CanCopyImpl(const OrtDataTransferImpl* this_ptr, const OrtMemoryDevice* src_memory_device,
1819
const OrtMemoryDevice* dst_memory_device) noexcept;
1920

2021
// function to copy one or more tensors.
2122
// implementation can optionally use async copy if a stream is available for the input.
22-
static OrtStatus* ORT_API_CALL CopyTensorsImpl(void* this_ptr, const OrtValue** src_tensors_ptr,
23+
static OrtStatus* ORT_API_CALL CopyTensorsImpl(OrtDataTransferImpl* this_ptr, const OrtValue** src_tensors_ptr,
2324
OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr,
2425
size_t num_tensors) noexcept;
25-
static void ORT_API_CALL ReleaseImpl(void* this_ptr) noexcept;
26+
static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept;
2627

2728
private:
28-
std::vector<const OrtMemoryDevice*> cuda_gpu_mem_devices_;
29-
std::vector<const OrtMemoryDevice*> cuda_pinned_mem_devices_;
29+
std::vector<const OrtMemoryDevice*>& cuda_gpu_mem_devices_;
30+
std::vector<const OrtMemoryDevice*>& cuda_pinned_mem_devices_;
3031
};

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e
2727
ReleaseAllocator = ReleaseAllocatorImpl;
2828

2929
CreateDataTransfer = CreateDataTransferImpl;
30+
31+
IsStreamAware = IsStreamAwareImpl;
3032
}
3133

3234
const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept {
@@ -80,24 +82,19 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
8082
size_t& num_ep_devices = *p_num_ep_devices;
8183
auto* factory = static_cast<TensorrtExecutionProviderFactory*>(this_ptr);
8284

85+
// Create two memory infos per device.
86+
// The memory info is required to create allocator and gpu data transfer.
8387
int num_cuda_devices = 0;
8488
cudaGetDeviceCount(&num_cuda_devices);
8589
RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices));
8690

87-
std::vector<const OrtMemoryDevice*> cuda_gpu_mem_devices;
88-
std::vector<const OrtMemoryDevice*> cuda_pinned_mem_devices;
8991
int32_t device_id = 0;
9092

9193
for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
9294
// C API
9395
const OrtHardwareDevice& device = *devices[i];
94-
if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
95-
96-
// workaround for duplicate devices when using remote desktop.
97-
if (device_id > 0) {
98-
continue;
99-
}
10096

97+
if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
10198
// These can be returned as nullptr if you have nothing to add.
10299
OrtKeyValuePairs* ep_metadata = nullptr;
103100
OrtKeyValuePairs* ep_options = nullptr;
@@ -129,8 +126,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
129126
RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_mem_info));
130127

131128
// Get memory device from memory info for gpu data transfer
132-
cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_mem_info));
133-
cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_mem_info));
129+
factory->cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_mem_info));
130+
factory->cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_mem_info));
134131

135132
ep_devices[num_ep_devices++] = ep_device;
136133
++device_id;
@@ -152,10 +149,12 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
152149

153150
// Create gpu data transfer
154151
auto data_transfer_impl = std::make_unique<TRTEpDataTransfer>(static_cast<const ApiPtrs&>(*factory),
155-
cuda_gpu_mem_devices, // device memory
156-
cuda_pinned_mem_devices // shared memory
152+
factory->cuda_gpu_mem_devices, // device memory
153+
factory->cuda_pinned_mem_devices // shared memory
157154
);
158-
factory->SetGPUDataTransfer(std::move(data_transfer_impl));
155+
156+
factory->data_transfer_impl = std::move(data_transfer_impl);
157+
159158
return nullptr;
160159
}
161160

@@ -244,13 +243,13 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl
244243
OrtEpFactory* this_ptr,
245244
OrtDataTransferImpl** data_transfer) noexcept {
246245
auto& factory = *static_cast<TensorrtExecutionProviderFactory*>(this_ptr);
247-
*data_transfer = factory.data_transfer_impl_.get();
246+
*data_transfer = factory.data_transfer_impl.get();
248247

249248
return nullptr;
250249
}
251250

252-
void TensorrtExecutionProviderFactory::SetGPUDataTransfer(std::unique_ptr<TRTEpDataTransfer> gpu_data_transfer) {
253-
data_transfer_impl_ = std::move(gpu_data_transfer);
251+
bool ORT_API_CALL TensorrtExecutionProviderFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
252+
return false;
254253
}
255254

256255
// To make symbols visible on macOS/iOS
@@ -265,6 +264,7 @@ extern "C" {
265264
// Public symbols
266265
//
267266
EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base,
267+
const OrtLogger*,
268268
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
269269
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);
270270
const OrtEpApi* ort_ep_api = ort_api->GetEpApi();
@@ -285,7 +285,7 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const
285285
}
286286

287287
EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) {
288-
delete factory;
288+
delete static_cast<TensorrtExecutionProviderFactory*>(factory);
289289
return nullptr;
290290
}
291291

plugin_execution_providers/tensorrt/tensorrt_provider_factory.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
1212
public:
1313
TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis);
1414

15-
const OrtMemoryInfo* GetDefaultGpuMemInfoForDeviceId(uint32_t device_id) const;
16-
17-
const OrtMemoryInfo* GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const;
18-
1915
OrtStatus* CreateMemoryInfoForDevices(int num_devices);
2016

2117
// CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo
@@ -25,6 +21,10 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
2521
std::vector<MemoryInfoUniquePtr> cuda_pinned_memory_infos;
2622
std::unordered_map<uint32_t, const OrtMemoryInfo*> device_id_to_cuda_gpu_memory_info_map; // device id -> OrtMemoryInfo
2723

24+
std::vector<const OrtMemoryDevice*> cuda_gpu_mem_devices;
25+
std::vector<const OrtMemoryDevice*> cuda_pinned_mem_devices;
26+
std::unique_ptr<TRTEpDataTransfer> data_transfer_impl; // data transfer implementation for this factory
27+
2828
private:
2929
static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept;
3030

@@ -53,11 +53,11 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
5353
static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr,
5454
OrtDataTransferImpl** data_transfer) noexcept;
5555

56+
static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept;
57+
5658
void SetGPUDataTransfer(std::unique_ptr<TRTEpDataTransfer> gpu_data_transfer);
5759

5860
const std::string ep_name_; // EP name
5961
const std::string vendor_{"Nvidia"}; // EP vendor name
6062
const std::string ep_version_{"0.1.0"}; // EP version
61-
62-
std::unique_ptr<TRTEpDataTransfer> data_transfer_impl_; // data transfer implementation for this factory
6363
};

0 commit comments

Comments (0)