Commit be453b1 (parent 7851a1c)

add allocator and data transfer

11 files changed: +322 −763 lines

plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.cc renamed to plugin_execution_providers/tensorrt/cuda_allocator.cc

Lines changed: 1 addition & 4 deletions

@@ -3,11 +3,10 @@
 
 #include <cassert>
 #include <cuda_runtime_api.h>
-#include "tensorrt_cuda_allocator.h"
+#include "cuda_allocator.h"
 
 void CUDA_RETURN_IF_ERROR(cudaError_t res);
 
-namespace onnxruntime {
 void CUDAAllocator::CheckDevice(bool throw_when_fail) const {
 #ifndef NDEBUG
   // check device to match at debug build
@@ -75,5 +74,3 @@ void CUDAPinnedAllocator::Free(void* p) {
 const OrtMemoryInfo* CUDAPinnedAllocator::Info() const {
   return mem_info_;
 }
-
-}  // namespace onnxruntime
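The diff only forward-declares CUDA_RETURN_IF_ERROR; its definition is not part of this commit. Purely as a sketch under that assumption (a void helper with this signature that fails hard on a CUDA error), it might look like:

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime_api.h>

// Hypothetical sketch - not the definition from this commit.
// Checks a CUDA runtime result and aborts with a message on failure.
void CUDA_RETURN_IF_ERROR(cudaError_t res) {
  if (res != cudaSuccess) {
    std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(res));
    std::abort();
  }
}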

plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.h renamed to plugin_execution_providers/tensorrt/cuda_allocator.h

Lines changed: 4 additions & 8 deletions

@@ -7,16 +7,13 @@
 #define ORT_API_MANUAL_INIT
 #include "onnxruntime_cxx_api.h"
 
-namespace onnxruntime {
-
-// Following names are originally defined in allocator.h
 constexpr const char* CUDA_ALLOCATOR = "Cuda";
 constexpr const char* CUDA_PINNED_ALLOCATOR = "CudaPinned";
 
 using DeviceId = int16_t;
 
 struct CUDAAllocator : OrtAllocator {
-  CUDAAllocator(DeviceId device_id, const char* name = onnxruntime::CUDA_ALLOCATOR) {
+  CUDAAllocator(DeviceId device_id, const char* name = CUDA_ALLOCATOR) {
     OrtAllocator::version = ORT_API_VERSION;
     OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<CUDAAllocator*>(this_)->Alloc(size); };
     OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<CUDAAllocator*>(this_)->Free(p); };
@@ -31,6 +28,7 @@ struct CUDAAllocator : OrtAllocator {
                        OrtMemType::OrtMemTypeDefault,
                        &mem_info_);
   }
+  // TODO: Handle destructor
   //~CUDAAllocator();
 
   void* Alloc(size_t size);
@@ -50,7 +48,7 @@ struct CUDAAllocator : OrtAllocator {
 };
 
 struct CUDAPinnedAllocator : OrtAllocator {
-  CUDAPinnedAllocator(const char* name = onnxruntime::CUDA_PINNED_ALLOCATOR) {
+  CUDAPinnedAllocator(const char* name = CUDA_PINNED_ALLOCATOR) {
     OrtAllocator::version = ORT_API_VERSION;
     OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<CUDAPinnedAllocator*>(this_)->Alloc(size); };
     OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<CUDAPinnedAllocator*>(this_)->Free(p); };
@@ -62,6 +60,7 @@ struct CUDAPinnedAllocator : OrtAllocator {
                        OrtMemType::OrtMemTypeDefault,
                        &mem_info_);
   }
+  // TODO: Handle destructor
   //~CUDAPinnedAllocator();
 
   void* Alloc(size_t size);
@@ -77,6 +76,3 @@ struct CUDAPinnedAllocator : OrtAllocator {
   DeviceId device_id_ = 0;
   OrtMemoryInfo* mem_info_ = nullptr;
 };
-
-
-}  // namespace onnxruntime
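For orientation only (not part of the commit), a minimal sketch of how this allocator is exercised through the OrtAllocator function-pointer interface. It assumes the ORT API has already been initialized (the header builds with ORT_API_MANUAL_INIT) and that a CUDA device is available:

#include "cuda_allocator.h"

// Hypothetical usage sketch: allocate and free device memory through the
// OrtAllocator members that CUDAAllocator wires up in its constructor.
void AllocatorDemo() {
  CUDAAllocator cuda_alloc(/*device_id*/ 0);  // CUDAAllocator is-a OrtAllocator
  OrtAllocator* base = &cuda_alloc;

  void* p = base->Alloc(base, 1024);  // dispatches to CUDAAllocator::Alloc
  // ... use p as device memory ...
  base->Free(base, p);                // dispatches to CUDAAllocator::Free
}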

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 8 additions & 676 deletions
Large diffs are not rendered by default.

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 5 additions & 7 deletions

@@ -152,6 +152,9 @@ class OutputAllocator : public nvinfer1::IOutputAllocator {
 
 using ShapeRangesMap = std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>;
 
+template <typename T>
+using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
+
 struct TensorrtComputeState {
   std::string fused_node_name;
   nvinfer1::IBuilder* builder;
@@ -168,14 +171,14 @@
   bool int8_calibration_cache_available = false;
   bool dla_enable = false;
   int dla_core = 0;
-  size_t* max_workspace_size_ptr = nullptr;
   std::string trt_node_name_with_precision;
   bool engine_cache_enable = false;
   std::string engine_cache_path;
   nvinfer1::IRuntime* runtime = nullptr;
   std::vector<nvinfer1::IOptimizationProfile*> profiles;
   bool context_memory_sharing_enable = false;
   size_t* max_context_mem_size_ptr = nullptr;
+  IAllocatorUniquePtr<void>* context_memory = nullptr;
   std::unordered_map<std::string, float> dynamic_range_map;
   bool engine_decryption_enable = false;
   int (*engine_decryption)(const char*, char*, size_t*) = nullptr;
@@ -215,11 +218,6 @@ static const std::string k_cc_hw_compatible = "80+";
 static const std::string k_ep_ctx_hardware_architecture = "hardware_architecture";
 static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename";
 
-struct ApiPtrs {
-  const OrtApi& ort_api;
-  const OrtEpApi& ep_api;
-};
-
 /// <summary>
 ///
 /// Plugin TensorRT EP OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph.
@@ -346,7 +344,7 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
   bool context_memory_sharing_enable_ = false;
   bool layer_norm_fp32_fallback_ = false;
   size_t max_ctx_mem_size_ = 0;
-  // IAllocatorUniquePtr<void> context_memory_ = nullptr;
+  IAllocatorUniquePtr<void> context_memory_ = nullptr;
   mutable char model_path_[4096] = {};  // Reserved for max path length
   bool engine_decryption_enable_ = false;
   int (*engine_decryption_)(const char*, char*, size_t*) = nullptr;
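The new IAllocatorUniquePtr alias pairs a raw allocation with a type-erased deleter. As an illustration only (the MakeUniquePtrFromOrtAllocator helper below is hypothetical, not something this commit defines), the alias is typically populated by binding the deleter to the allocator that produced the memory, e.g. for the shared context memory:

// Hypothetical helper, shown only to illustrate the IAllocatorUniquePtr alias.
template <typename T>
IAllocatorUniquePtr<T> MakeUniquePtrFromOrtAllocator(OrtAllocator* allocator, size_t bytes) {
  void* p = allocator->Alloc(allocator, bytes);
  // The deleter captures the allocator so the memory is returned to it on reset/destruction.
  return IAllocatorUniquePtr<T>(static_cast<T*>(p),
                                [allocator](T* ptr) { allocator->Free(allocator, ptr); });
}

// e.g. (sketch) growing the shared execution-context memory:
// context_memory_ = MakeUniquePtrFromOrtAllocator<void>(cuda_allocator, *max_context_mem_size_ptr);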
plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "tensorrt_execution_provider_data_transfer.h"
+
+#include <cassert>
+#include <gsl/span>
+
+void CUDA_RETURN_IF_ERROR(cudaError_t res);
+
+/*static*/
+bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(void* this_ptr,
+                                                 const OrtMemoryDevice* src_memory_device,
+                                                 const OrtMemoryDevice* dst_memory_device) noexcept {
+  auto& impl = *static_cast<TRTEpDataTransfer*>(this_ptr);
+  bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, impl.device_mem_info);
+  bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, impl.device_mem_info);
+
+  return src_is_our_device || dst_is_our_device;
+}
+
+// Function to copy one or more tensors.
+// The implementation can optionally use an async copy if a stream is available for the input.
+/*static*/
+OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(void* this_ptr,
+                                                           const OrtValue** src_tensors_ptr,
+                                                           OrtValue** dst_tensors_ptr,
+                                                           OrtSyncStream** streams_ptr,
+                                                           size_t num_tensors) noexcept {
+  auto& impl = *static_cast<TRTEpDataTransfer*>(this_ptr);
+
+  auto src_tensors = gsl::make_span<const OrtValue*>(src_tensors_ptr, num_tensors);
+  auto dst_tensors = gsl::make_span<OrtValue*>(dst_tensors_ptr, num_tensors);
+  auto streams = gsl::make_span<OrtSyncStream*>(streams_ptr, num_tensors);
+
+  for (size_t i = 0; i < num_tensors; ++i) {
+    // NOTE: Stream support will be a separate PR. Ignore the streams_ptr values for now.
+
+    const OrtMemoryDevice* src_device = nullptr;
+    const OrtMemoryDevice* dst_device = nullptr;
+    RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(src_tensors[i], &src_device));
+    RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(dst_tensors[i], &dst_device));
+
+    OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device);
+    OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device);
+    OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device);
+    OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device);
+    bool copy_involves_pinned_memory = src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE ||
+                                       dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE;
+
+    const void* src_data = nullptr;
+    void* dst_data = nullptr;
+    RETURN_IF_ERROR(impl.ort_api.GetTensorData(src_tensors[i], &src_data));
+    RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensors[i], &dst_data));
+
+    size_t bytes = 0;
+    RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(src_tensors[i], &bytes));
+
+    // For the sync version of memcpy, launch to the CUDA default stream.
+    if (dst_device_type == OrtMemoryInfoDeviceType_GPU) {
+      if (src_device_type == OrtMemoryInfoDeviceType_GPU) {
+        // GPU -> GPU
+        // Copy only if the two addresses are different and bytes > 0.
+        if (dst_data != src_data && bytes > 0) {
+          CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice));
+          // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy.
+          // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html
+          CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
+        }
+      } else {
+        // CPU -> GPU, this is blocking
+        CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice));
+        if (src_mem_type != OrtDeviceMemoryType_HOST_ACCESSIBLE) {
+          // For cudaMemcpy from pageable host memory to device memory, DMA to the final destination may not have completed.
+          // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html
+          CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr));
+        }
+      }
+    } else if (src_device_type == OrtMemoryInfoDeviceType_GPU) {
+      // GPU -> CPU, this is blocking
+      CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost));
+    } else {
+      // CPU -> CPU involves a copy to/from pinned memory and a synchronize may be required first
+      // ORT_ENFORCE(dst_data != src_data);
+      memcpy(dst_data, src_data, bytes);
+    }
+  }
+
+  return nullptr;
+}
+
+/*static*/
+void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(void* this_ptr) noexcept {
+  // In our setup the factory owns a shared ExampleDataTransfer instance, so it will do the cleanup and we ignore
+  // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h).
+  //
+  // If you create a new instance on each call to OrtEpFactory::CreateDataTransfer, you call `delete` here.
+  delete static_cast<TRTEpDataTransfer*>(this_ptr);
+}
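The loop above deliberately ignores streams_ptr (stream support is deferred to a separate PR). Purely as a sketch of where that work would slot in, and assuming a hypothetical helper that exposes the underlying cudaStream_t of an OrtSyncStream, the async path could look roughly like:

// Hypothetical: GetNativeCudaStream is not an API defined in this commit.
cudaStream_t stream = GetNativeCudaStream(streams[i]);
if (stream != nullptr) {
  // Enqueue on the caller-provided stream instead of the default stream.
  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDefault, stream));
  // Synchronization then becomes the responsibility of the caller / the ORT sync-stream machinery.
} else {
  // Fall back to the synchronous copies shown above.
}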
plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "tensorrt_execution_provider_utils.h"
+
+struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs {
+  TRTEpDataTransfer(ApiPtrs api_ptrs, const OrtMemoryDevice* device_mem_info_,
+                    const OrtMemoryDevice* shared_mem_info_ = nullptr)
+      : ApiPtrs(api_ptrs), device_mem_info{device_mem_info_}, shared_mem_info{shared_mem_info_} {
+    CanCopy = CanCopyImpl;
+    CopyTensors = CopyTensorsImpl;
+    Release = ReleaseImpl;
+  }
+
+  static bool ORT_API_CALL CanCopyImpl(void* this_ptr, const OrtMemoryDevice* src_memory_device,
+                                       const OrtMemoryDevice* dst_memory_device) noexcept;
+
+  // Function to copy one or more tensors.
+  // The implementation can optionally use an async copy if a stream is available for the input.
+  static OrtStatus* ORT_API_CALL CopyTensorsImpl(void* this_ptr, const OrtValue** src_tensors_ptr,
+                                                 OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr,
+                                                 size_t num_tensors) noexcept;
+  static void ORT_API_CALL ReleaseImpl(void* this_ptr) noexcept;
+
+ private:
+  const OrtMemoryDevice* device_mem_info;
+  const OrtMemoryDevice* shared_mem_info;
+};
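For context (not part of this diff): the constructor wires CanCopy, CopyTensors and Release to the static implementations, so an EP factory can hand the object straight out as an OrtDataTransferImpl*. A minimal sketch, assuming the factory already holds ApiPtrs and the device's OrtMemoryDevice* (names below are placeholders):

// Hypothetical sketch of the factory side.
OrtStatus* CreateTrtDataTransfer(ApiPtrs apis,
                                 const OrtMemoryDevice* device_mem_info,
                                 OrtDataTransferImpl** out) {
  *out = new TRTEpDataTransfer(apis, device_mem_info);  // released later via ReleaseImpl -> delete
  return nullptr;  // success
}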

plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc

Lines changed: 0 additions & 2 deletions

@@ -7,7 +7,6 @@
 #include "provider_options_utils.h"
 #include "cuda/cuda_common.h"
 
-namespace onnxruntime {
 namespace tensorrt {
 namespace provider_option_names {
 constexpr const char* kDeviceId = "device_id";
@@ -336,4 +335,3 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
 //  trt_provider_options_v2.trt_ep_context_file_path = copy_string_if_needed(internal_options.ep_context_file_path);
 //  trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible;
 //}
-}  // namespace onnxruntime

plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h

Lines changed: 1 addition & 3 deletions

@@ -9,7 +9,6 @@
 
 #define TRT_DEFAULT_OPTIMIZER_LEVEL 3
 
-namespace onnxruntime {
 // Information needed to construct trt execution providers.
 struct TensorrtExecutionProviderInfo {
   int device_id{0};
@@ -55,11 +54,10 @@ struct TensorrtExecutionProviderInfo {
   std::string engine_cache_prefix{""};
   bool engine_hw_compatible{false};
 
-  static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
+  static TensorrtExecutionProviderInfo FromProviderOptions(const onnxruntime::ProviderOptions& options);
   // static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
   // static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info);
   // static void UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy);
   //
   // std::vector<OrtCustomOpDomain*> custom_op_domain_list;
 };
-}  // namespace onnxruntime

plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h

Lines changed: 35 additions & 9 deletions

@@ -1,20 +1,45 @@
+#define ORT_API_MANUAL_INIT
+#include "onnxruntime_cxx_api.h"
+#undef ORT_API_MANUAL_INIT
+
+#include "flatbuffers/idl.h"
+#include "ort_trt_int8_cal_table.fbs.h"
+// #include "core/providers/cuda/cuda_pch.h"
+// #include "core/common/path_string.h"
+// #include "core/framework/murmurhash3.h"
+
+#include "nv_includes.h"
+
 #include <fstream>
 #include <unordered_map>
 #include <string>
 #include <vector>
 #include <sstream>
 #include <iostream>
 #include <filesystem>
-#include "flatbuffers/idl.h"
-#include "ort_trt_int8_cal_table.fbs.h"
-#include <NvInferVersion.h>
-//#include "core/providers/cuda/cuda_pch.h"
-//#include "core/common/path_string.h"
-//#include "core/framework/murmurhash3.h"
 
-namespace fs = std::filesystem;
+#define RETURN_IF_ERROR(fn)        \
+  do {                             \
+    OrtStatus* _status = (fn);     \
+    if (_status != nullptr) {      \
+      return _status;              \
+    }                              \
+  } while (0)
+
+#define RETURN_IF(cond, ort_api, msg)                      \
+  do {                                                     \
+    if ((cond)) {                                          \
+      return (ort_api).CreateStatus(ORT_EP_FAIL, (msg));   \
+    }                                                      \
+  } while (0)
+
+struct ApiPtrs {
+  const OrtApi& ort_api;
+  const OrtEpApi& ep_api;
+  const OrtModelEditorApi& model_editor_api;
+};
 
-//namespace onnxruntime {
+namespace fs = std::filesystem;
 
 // Check if cycle exists in the graph after partitioning
 /*
@@ -143,6 +168,7 @@ std::vector<std::string> SplitToStringVec(std::string const& s, char separator)
   return splitted;
 }
 
+/*
 nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) {
   nvinfer1::TacticSources disabledTactics = 0;
   nvinfer1::TacticSources enabledTactics = 0;
@@ -197,6 +223,7 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) {
   }
   return enabledTactics & ~disabledTactics;
 }
+*/
 
 inline std::vector<char> loadTimingCacheFile(const std::string inFileName) {
   std::ifstream iFile(inFileName, std::ios::in | std::ios::binary);
@@ -968,4 +995,3 @@ std::string GetCacheSuffix(const std::string& fused_node_name, const std::string
   }
   return "";
 }
-//}  // namespace onnxruntime
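The two status macros added at the top of this header are used throughout the EP sources. A short usage sketch (the ReadTensorData function below is illustrative, not from the commit):

// RETURN_IF turns a predicate into an ORT_EP_FAIL status; RETURN_IF_ERROR propagates
// a non-null OrtStatus* from a nested OrtApi call.
OrtStatus* ReadTensorData(const ApiPtrs& apis, const OrtValue* value, const void** data) {
  RETURN_IF(value == nullptr, apis.ort_api, "value must not be null");
  RETURN_IF_ERROR(apis.ort_api.GetTensorData(value, data));
  return nullptr;  // success
}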
