Commit 30e0f91
update memory info and data transfer in TRT EP's factory to accommodate multiple GPU devices
1 parent 731ed72 commit 30e0f91

File tree: 7 files changed, +158 -81 lines

plugin_execution_providers/tensorrt/cuda_allocator.h

Lines changed: 4 additions & 13 deletions
@@ -13,22 +13,13 @@ constexpr const char* CUDA_PINNED_ALLOCATOR = "CudaPinned";
 using DeviceId = int16_t;
 
 struct CUDAAllocator : OrtAllocator {
-  CUDAAllocator(const OrtMemoryInfo* mem_info, const char* name = CUDA_ALLOCATOR) {
+  CUDAAllocator(const OrtMemoryInfo* mem_info, DeviceId device_id) : mem_info_(mem_info), device_id_(device_id) {
     OrtAllocator::version = ORT_API_VERSION;
-    OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<CUDAAllocator*>(this_)->Alloc(size); };
+    OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) {
+      return static_cast<CUDAAllocator*>(this_)->Alloc(size);
+    };
     OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<CUDAAllocator*>(this_)->Free(p); };
     OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const CUDAAllocator*>(this_)->Info(); };
-
-    mem_info_ = mem_info;
-
-    device_id_ = device_id;
-
-    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
-    api->CreateMemoryInfo(name,
-                          OrtAllocatorType::OrtDeviceAllocator,
-                          static_cast<int>(device_id),
-                          OrtMemType::OrtMemTypeDefault,
-                          &mem_info_);
   }
   // TODO: Handle destructor
   //~CUDAAllocator();
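
Side note on the pattern above: the constructor works because a captureless lambda converts implicitly to the plain C function pointer that the OrtAllocator struct expects, and per-allocator state is recovered through the this_ pointer. A reduced, self-contained sketch of that trampoline pattern (CAllocator and MyAllocator are illustrative stand-ins, not names from this commit):

    #include <cstddef>

    struct CAllocator {  // stand-in for the OrtAllocator C struct
      void* (*Alloc)(CAllocator* this_, std::size_t size);
      void (*Free)(CAllocator* this_, void* p);
    };

    struct MyAllocator : CAllocator {
      MyAllocator() {
        // Captureless lambdas decay to C function pointers and dispatch to members.
        Alloc = [](CAllocator* this_, std::size_t size) {
          return static_cast<MyAllocator*>(this_)->AllocImpl(size);
        };
        Free = [](CAllocator* this_, void* p) {
          static_cast<MyAllocator*>(this_)->FreeImpl(p);
        };
      }
      void* AllocImpl(std::size_t size) { return ::operator new(size); }  // real EP: cudaMalloc on device_id_
      void FreeImpl(void* p) { ::operator delete(p); }                    // real EP: cudaFree
    };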

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ OrtStatusPtr BindContextOutput(Ort::KernelContext& ctx,
654654
}
655655

656656
OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx,
657-
OrtMemoryInfo* /*mem_info*/,
657+
const OrtMemoryInfo* /*mem_info*/,
658658
DDSOutputAllocatorMap& allocator_map,
659659
char const* output_name,
660660
size_t output_index,
@@ -1416,6 +1416,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
14161416
tactics = GetTacticSourceFromString(tactic_sources_);
14171417
}
14181418
*compute_state = {
1419+
static_cast<uint32_t>(device_id_),
14191420
fused_node_name,
14201421
builder_.get(),
14211422
&parsers_[fused_node_name],
@@ -2281,6 +2282,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void*
22812282
std::unordered_map<std::string, std::vector<int64_t>>
22822283
shape_tensor_values_int64; // same as above but for int64 shape tensor input
22832284

2285+
uint16_t device_id = trt_state->device_id;
22842286
auto max_workspace_size = trt_state->max_workspace_size;
22852287
auto trt_builder = trt_state->builder;
22862288
auto trt_engine = trt_state->engine->get();
@@ -2317,7 +2319,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void*
23172319

23182320
// Get default OrtMemoryInfo from factory
23192321
// Get allocator from OrtKernelContext
2320-
OrtMemoryInfo* mem_info = ep.factory_.GetDefaultMemInfo();
2322+
const OrtMemoryInfo* mem_info = ep.factory_.GetDefaultGpuMemInfoForDeviceId(device_id);
23212323
OrtAllocator* alloc = nullptr;
23222324
ep.GetAllocator(&alloc);
23232325
if (alloc == nullptr) {
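
Two details in this hunk are easy to miss: the state member is declared uint32_t in the header below but read back into a uint16_t here (matching the int16_t DeviceId width used by CUDAAllocator), and GetDefaultGpuMemInfoForDeviceId returns nullptr when no OrtMemoryInfo was registered for the id. A minimal sketch of a defensive lookup, assuming only the factory API from this commit (GetGpuMemInfoOrThrow itself is hypothetical):

    #include <stdexcept>
    #include <string>

    #include "tensorrt_provider_factory.h"

    // Hypothetical helper: fail loudly rather than hand a null OrtMemoryInfo
    // to the binding code when a device id was never registered.
    const OrtMemoryInfo* GetGpuMemInfoOrThrow(const TensorrtExecutionProviderFactory& factory,
                                              uint32_t device_id) {
      const OrtMemoryInfo* mem_info = factory.GetDefaultGpuMemInfoForDeviceId(device_id);
      if (mem_info == nullptr) {
        throw std::runtime_error("no default GPU OrtMemoryInfo registered for device " +
                                 std::to_string(device_id));
      }
      return mem_info;
    }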

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 2 additions & 0 deletions
@@ -153,6 +153,7 @@ class OutputAllocator : public nvinfer1::IOutputAllocator {
 };
 
 struct TensorrtComputeState {
+  uint32_t device_id;
   std::string fused_node_name;
   nvinfer1::IBuilder* builder;
   tensorrt_ptr::unique_pointer<nvonnxparser::IParser>* parser = nullptr;
@@ -207,6 +208,7 @@ struct TensorrtComputeState {
 
 // Minimum information to construct kernel function state for direct engine load code path
 struct TensorrtComputeStateForEPContext {
+  uint32_t device_id;
   std::string fused_node_name;
   std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;
   std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
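
The new device_id member sits first in both structs because CreateNodeComputeInfoFromGraph fills the state with a positional braced initializer (the *compute_state = { ... } hunk above), so member order and initializer order must stay in sync. A reduced illustration (the sketch struct is not the real compute state):

    #include <cstdint>
    #include <string>
    #include <utility>

    struct ComputeStateSketch {
      uint32_t device_id;           // new member, listed first
      std::string fused_node_name;  // existing members keep their relative order
    };

    ComputeStateSketch MakeState(uint32_t device_id, std::string name) {
      // Positional aggregate initialization binds values to members in
      // declaration order; this is why the struct and the initializer list
      // in the .cc file have to change together.
      return {device_id, std::move(name)};
    }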

plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc

Lines changed: 13 additions & 6 deletions
@@ -9,14 +9,21 @@
 void CUDA_RETURN_IF_ERROR(cudaError_t res);
 
 /*static*/
-bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(void* this_ptr,
-                                                 const OrtMemoryDevice* src_memory_device,
-                                                 const OrtMemoryDevice* dst_memory_device) noexcept {
+bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(void* this_ptr, const OrtMemoryDevice* src_memory_device,
+                                                 const OrtMemoryDevice* dst_memory_device) noexcept {
   auto& impl = *static_cast<TRTEpDataTransfer*>(this_ptr);
-  bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, impl.device_mem_info);
-  bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, impl.device_mem_info);
 
-  return src_is_our_device || dst_is_our_device;
+  auto it = std::find_if(impl.cuda_gpu_mem_devices_.begin(), impl.cuda_gpu_mem_devices_.end(),
+                         [&impl, &src_memory_device, &dst_memory_device](const OrtMemoryDevice* memory_device) {
+                           bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, memory_device);
+                           bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, memory_device);
+                           return src_is_our_device || dst_is_our_device;
+                         });
+
+  if (it != impl.cuda_gpu_mem_devices_.end()) {
+    return true;
+  }
+  return false;
 }
 
 // function to copy one or more tensors.
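
Note that CanCopyImpl walks only cuda_gpu_mem_devices_, so a copy is allowed whenever either endpoint is one of the EP's GPU devices; the pinned-memory list is not consulted here. The find_if-plus-end() comparison can also be stated more directly with std::any_of; an equivalent sketch, written as a hypothetical member function so the private vector is in scope:

    #include <algorithm>

    // Hypothetical member of TRTEpDataTransfer; behavior matches CanCopyImpl above.
    bool TRTEpDataTransfer::CanCopySketch(const OrtMemoryDevice* src_memory_device,
                                          const OrtMemoryDevice* dst_memory_device) const {
      return std::any_of(cuda_gpu_mem_devices_.begin(), cuda_gpu_mem_devices_.end(),
                         [&](const OrtMemoryDevice* memory_device) {
                           return ep_api.MemoryDevice_AreEqual(src_memory_device, memory_device) ||
                                  ep_api.MemoryDevice_AreEqual(dst_memory_device, memory_device);
                         });
    }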

plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h

Lines changed: 5 additions & 5 deletions
@@ -6,9 +6,9 @@
 #include "tensorrt_execution_provider_utils.h"
 
 struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs {
-  TRTEpDataTransfer(ApiPtrs api_ptrs, const OrtMemoryDevice* device_mem_info_,
-                    const OrtMemoryDevice* shared_mem_info_ = nullptr)
-      : ApiPtrs(api_ptrs), device_mem_info{device_mem_info_}, shared_mem_info{shared_mem_info_} {
+  TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector<const OrtMemoryDevice*> device_mem_infos,
+                    std::vector<const OrtMemoryDevice*> shared_mem_infos)
+      : ApiPtrs(api_ptrs), cuda_gpu_mem_devices_{device_mem_infos}, cuda_pinned_mem_devices_{shared_mem_infos} {
     CanCopy = CanCopyImpl;
     CopyTensors = CopyTensorsImpl;
     Release = ReleaseImpl;
@@ -25,6 +25,6 @@ struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs {
   static void ORT_API_CALL ReleaseImpl(void* this_ptr) noexcept;
 
  private:
-  const OrtMemoryDevice* device_mem_info;
-  const OrtMemoryDevice* shared_mem_info;
+  std::vector<const OrtMemoryDevice*> cuda_gpu_mem_devices_;
+  std::vector<const OrtMemoryDevice*> cuda_pinned_mem_devices_;
 };
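
Since both vectors are taken by value and then copied into the members, moving them would save one copy each. A minor variant of the constructor above (drop-in, requires <utility>; behavior is otherwise identical):

    TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector<const OrtMemoryDevice*> device_mem_infos,
                      std::vector<const OrtMemoryDevice*> shared_mem_infos)
        : ApiPtrs(api_ptrs),
          cuda_gpu_mem_devices_{std::move(device_mem_infos)},      // move instead of copy
          cuda_pinned_mem_devices_{std::move(shared_mem_infos)} {
      CanCopy = CanCopyImpl;
      CopyTensors = CopyTensorsImpl;
      Release = ReleaseImpl;
    }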

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 104 additions & 47 deletions
@@ -28,32 +28,6 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e
   ReleaseAllocator = ReleaseAllocatorImpl;
 
   CreateDataTransfer = CreateDataTransferImpl;
-
-  // Default GPU allocator OrtMemoryInfo
-  OrtMemoryInfo* mem_info = nullptr;
-  auto* status = ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU,
-                                             /*vendor*/ 0x10DE, /* device_id */ 0, OrtDeviceMemoryType_DEFAULT,
-                                             /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info);
-  assert(status == nullptr);  // should never fail.
-  default_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo);
-
-  // CUDA PINNED allocator OrtMemoryInfo
-  // HOST_ACCESSIBLE memory should use the non-CPU device type
-  mem_info = nullptr;
-  status = ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU,
-                                       /*vendor*/ 0x10DE, /* device_id */ 0, OrtDeviceMemoryType_HOST_ACCESSIBLE,
-                                       /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info);
-  assert(status == nullptr);  // should never fail.
-  host_accessible_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo);
-
-  // Create gpu data transfer
-  data_transfer_impl_ = std::make_unique<TRTEpDataTransfer>(
-      apis,
-      ep_api.MemoryInfo_GetMemoryDevice(default_gpu_memory_info_.get()),         // device memory
-      ep_api.MemoryInfo_GetMemoryDevice(host_accessible_gpu_memory_info_.get())  // shared memory
-  );
-
-  data_transfer_impl_.reset();  // but we're CPU only so we return nullptr for the IDataTransfer.
 }
 
 const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept {
@@ -76,6 +50,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
   size_t& num_ep_devices = *p_num_ep_devices;
   auto* factory = static_cast<TensorrtExecutionProviderFactory*>(this_ptr);
 
+  std::vector<const OrtMemoryDevice*> cuda_gpu_mem_devices;
+  std::vector<const OrtMemoryDevice*> cuda_pinned_mem_devices;
+
   for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
     // C API
     const OrtHardwareDevice& device = *devices[i];
@@ -88,7 +65,7 @@
 
       // The ep options can be provided here as default values.
      // Users can also call SessionOptionsAppendExecutionProvider_V2 C API with provided ep options to override.
-      factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1");  // random example using made up values
+      factory->ort_api.AddKeyValuePair(ep_metadata, "gpu_type", "data center");  // random example using made up values
       factory->ort_api.AddKeyValuePair(ep_options, "trt_builder_optimization_level", "3");
 
       // OrtEpDevice copies ep_metadata and ep_options.
@@ -103,25 +80,60 @@
         return status;
       }
 
-      // register the allocator info required by the EP.
-      RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->default_gpu_memory_info_.get()));
-      RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->host_accessible_gpu_memory_info_.get()));
+      uint32_t vendor_id = factory->ort_api.HardwareDevice_VendorId(&device);
+      uint32_t device_id = factory->ort_api.HardwareDevice_DeviceId(&device);
+
+      // CUDA allocator OrtMemoryInfo
+      OrtMemoryInfo* mem_info = nullptr;
+      status = factory->ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_DEFAULT,
+                                                    /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info);
+
+      assert(status == nullptr);  // should never fail.
+      MemoryInfoUniquePtr cuda_gpu_memory_info = MemoryInfoUniquePtr(mem_info, factory->ort_api.ReleaseMemoryInfo);
+
+      // CUDA PINNED allocator OrtMemoryInfo
+      // HOST_ACCESSIBLE memory should use the non-CPU device type.
+      mem_info = nullptr;
+      status = factory->ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE,
+                                                    /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info);
+
+      assert(status == nullptr);  // should never fail.
+      MemoryInfoUniquePtr cuda_pinned_memory_info = MemoryInfoUniquePtr(mem_info, factory->ort_api.ReleaseMemoryInfo);
+
+      // Register the allocator info required by TRT EP.
+      RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_gpu_memory_info.get()));
+      RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_memory_info.get()));
+
+      // Get memory device from memory info for gpu data transfer
+      cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_memory_info.get()));
+      cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_memory_info.get()));
+
+      factory->SetDefaultGpuMemInfo(std::move(cuda_gpu_memory_info), device_id);
+      factory->SetHostAccessibleMemInfo(std::move(cuda_pinned_memory_info), device_id);
 
       ep_devices[num_ep_devices++] = ep_device;
     }
 
-    // C++ API equivalent. Throws on error.
-    //{
-    //   Ort::ConstHardwareDevice device(devices[i]);
-    //   if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
-    //     Ort::KeyValuePairs ep_metadata;
-    //     Ort::KeyValuePairs ep_options;
-    //     ep_metadata.Add("version", "0.1");
-    //     ep_options.Add("trt_builder_optimization_level", "3");
-    //     Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()};
-    //     ep_devices[num_ep_devices++] = ep_device.release();
-    //   }
-    //}
+    // Create gpu data transfer
+    auto data_transfer_impl = std::make_unique<TRTEpDataTransfer>(
+        static_cast<const ApiPtrs&>(*factory),
+        cuda_gpu_mem_devices,    // device memory
+        cuda_pinned_mem_devices  // shared memory
+    );
+
+
+    // C++ API equivalent. Throws on error.
+    //{
+    //   Ort::ConstHardwareDevice device(devices[i]);
+    //   if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
+    //     Ort::KeyValuePairs ep_metadata;
+    //     Ort::KeyValuePairs ep_options;
+    //     ep_metadata.Add("version", "0.1");
+    //     ep_options.Add("trt_builder_optimization_level", "3");
+    //     Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()};
+    //     ep_devices[num_ep_devices++] = ep_device.release();
+    //   }
+    //}
   }
 
   return nullptr;
@@ -181,11 +193,14 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(
 
   // NOTE: The OrtMemoryInfo pointer should only ever be coming straight from an OrtEpDevice, and pointer based
   // matching should work.
-  if (memory_info == factory.default_gpu_memory_info_.get()) {
+
+  uint32_t device_id = 0;
+
+  if (factory.GetDeviceIdForDefaultGpuMemInfo(memory_info, &device_id)) {
     // create a CUDA allocator
-    auto cuda_allocator = std::make_unique<CUDAAllocator>(memory_info);
+    auto cuda_allocator = std::make_unique<CUDAAllocator>(memory_info, static_cast<uint16_t>(device_id));
     *allocator = cuda_allocator.release();
-  } else if (memory_info == factory.host_accessible_gpu_memory_info_.get()) {
+  } else if (factory.GetDeviceIdForHostAccessibleMemInfo(memory_info, &device_id)) {
     // create a CUDA PINNED allocator
     auto cuda_pinned_allocator = std::make_unique<CUDAPinnedAllocator>(memory_info);
     *allocator = cuda_pinned_allocator.release();
@@ -212,8 +227,50 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl
   return nullptr;
 }
 
-OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultMemInfo() const {
-  return default_gpu_memory_info_.get();
+bool TensorrtExecutionProviderFactory::GetDeviceIdForDefaultGpuMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const {
+  auto iter = cuda_gpu_memory_info_to_device_id_map_.find(mem_info);
+  if (iter != cuda_gpu_memory_info_to_device_id_map_.end()) {
+    *device_id = iter->second;
+    return true;
+  }
+  return false;
+}
+
+const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultGpuMemInfoForDeviceId(uint32_t device_id) const {
+  auto iter = device_id_to_cuda_gpu_memory_info_map_.find(device_id);
+  if (iter != device_id_to_cuda_gpu_memory_info_map_.end()) {
+    return iter->second;
+  }
+  return nullptr;
+}
+
+void TensorrtExecutionProviderFactory::SetDefaultGpuMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id) {
+  cuda_gpu_memory_info_to_device_id_map_[mem_info.get()] = device_id;
+  device_id_to_cuda_gpu_memory_info_map_[device_id] = mem_info.get();
+  cuda_gpu_memory_infos_.push_back(std::move(mem_info));
+}
+
+bool TensorrtExecutionProviderFactory::GetDeviceIdForHostAccessibleMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const {
+  auto iter = cuda_pinned_memory_info_to_device_id_map_.find(mem_info);
+  if (iter != cuda_pinned_memory_info_to_device_id_map_.end()) {
+    *device_id = iter->second;
+    return true;
+  }
+  return false;
+}
+
+const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const {
+  auto iter = device_id_to_cuda_pinned_memory_info_map_.find(device_id);
+  if (iter != device_id_to_cuda_pinned_memory_info_map_.end()) {
+    return iter->second;
+  }
+  return nullptr;
+}
+
+void TensorrtExecutionProviderFactory::SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id) {
+  cuda_pinned_memory_info_to_device_id_map_[mem_info.get()] = device_id;
+  device_id_to_cuda_pinned_memory_info_map_[device_id] = mem_info.get();
+  cuda_pinned_memory_infos_.push_back(std::move(mem_info));
 }
 
 // To make symbols visible on macOS/iOS
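
The four maps plus two ownership vectors added above form a bidirectional OrtMemoryInfo-to-device-id registry, maintained once for default GPU memory and once for pinned memory. A generic sketch of that bookkeeping, assuming nothing beyond the standard library (DeviceRegistry is illustrative, not a name from this commit):

    #include <cstdint>
    #include <memory>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    template <typename T, typename Deleter>
    struct DeviceRegistry {
      using Ptr = std::unique_ptr<T, Deleter>;

      void Register(Ptr info, uint32_t device_id) {
        info_to_id_[info.get()] = device_id;
        id_to_info_[device_id] = info.get();
        owned_.push_back(std::move(info));  // ownership keeps the raw pointers in the maps valid
      }
      bool GetDeviceId(const T* info, uint32_t* device_id) const {
        auto iter = info_to_id_.find(info);
        if (iter == info_to_id_.end()) return false;
        *device_id = iter->second;
        return true;
      }
      const T* GetInfo(uint32_t device_id) const {
        auto iter = id_to_info_.find(device_id);
        return iter == id_to_info_.end() ? nullptr : iter->second;
      }

     private:
      std::unordered_map<const T*, uint32_t> info_to_id_;
      std::unordered_map<uint32_t, const T*> id_to_info_;
      std::vector<Ptr> owned_;
    };

Two instances of such a registry (one per memory kind) could replace the duplicated Get/Set pairs, though the commit keeps them spelled out.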

plugin_execution_providers/tensorrt/tensorrt_provider_factory.h

Lines changed: 26 additions & 8 deletions
@@ -3,13 +3,18 @@
 #include "tensorrt_execution_provider_utils.h"
 #include "tensorrt_execution_provider_data_transfer.h"
 
+using MemoryInfoUniquePtr = std::unique_ptr<OrtMemoryInfo, std::function<void(OrtMemoryInfo*)>>;
+
 ///
 /// Plugin TensorRT EP factory that can create an OrtEp and return information about the supported hardware devices.
 ///
 struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
  public:
   TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis);
-  OrtMemoryInfo* GetDefaultMemInfo() const;
+
+  const OrtMemoryInfo* GetDefaultGpuMemInfoForDeviceId(uint32_t device_id) const;
+
+  const OrtMemoryInfo* GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const;
 
  private:
   static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept;
@@ -37,17 +42,30 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
   static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr,
                                                         OrtDataTransferImpl** data_transfer) noexcept;
 
+  bool GetDeviceIdForDefaultGpuMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const;
+
+  void SetDefaultGpuMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id);
+
+  bool GetDeviceIdForHostAccessibleMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const;
+
+  void SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id);
+
   const std::string ep_name_;           // EP name
   const std::string vendor_{"Nvidia"};  // EP vendor name
 
-  // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed.
-  using MemoryInfoUniquePtr = std::unique_ptr<OrtMemoryInfo, std::function<void(OrtMemoryInfo*)>>;
-  //MemoryInfoUniquePtr cpu_memory_info_;
+  // OrtMemoryInfo for allocators and data transfer.
+
+  // CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo instance required for that.
+  // Current TRT EP implementation uses one default OrtMemoryInfo and one host accessible OrtMemoryInfo per ep device.
+  std::unordered_map<const OrtMemoryInfo*, uint32_t> cuda_gpu_memory_info_to_device_id_map_;  // OrtMemoryInfo -> device id
+  std::unordered_map<const OrtMemoryInfo*, uint32_t> cuda_pinned_memory_info_to_device_id_map_;
+  std::unordered_map<uint32_t, const OrtMemoryInfo*> device_id_to_cuda_gpu_memory_info_map_;  // device id -> OrtMemoryInfo
+  std::unordered_map<uint32_t, const OrtMemoryInfo*> device_id_to_cuda_pinned_memory_info_map_;
+  std::vector<MemoryInfoUniquePtr> cuda_gpu_memory_infos_;
+  std::vector<MemoryInfoUniquePtr> cuda_pinned_memory_infos_;
 
-  // GPU memory and pinned/shared memory are required for data transfer, these are the
-  // OrtMemoryInfo instance required for that.
-  MemoryInfoUniquePtr default_gpu_memory_info_;
-  MemoryInfoUniquePtr host_accessible_gpu_memory_info_;
+  // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed.
+  // MemoryInfoUniquePtr cpu_memory_info_;
 
   std::unique_ptr<TRTEpDataTransfer> data_transfer_impl_;  // data transfer implementation for this factory
 };
