Commit 3d6fa57

fix a lot of compile errors
1 parent be453b1 commit 3d6fa57

8 files changed: +426 −535 lines changed


plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh

Lines changed: 0 additions & 2 deletions
@@ -4,7 +4,6 @@
 #pragma once
 #include <stdint.h>
 
-namespace onnxruntime {
 namespace cuda {
 
 // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
@@ -75,4 +74,3 @@ void UnaryElementWiseImpl(
 }
 
 } // namespace cuda
-} // namespace onnxruntime

plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu

Lines changed: 0 additions & 2 deletions
@@ -9,7 +9,6 @@
 #endif
 #include <cuda_fp16.h>
 
-namespace onnxruntime {
 
 namespace cuda {
 
@@ -90,4 +89,3 @@ IMPL_CAST_IMPL_FROM(bool)
 //IMPL_CAST_IMPL_FROM(BFloat16)
 
 } // namespace cuda
-} // namespace onnxruntime

plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h

Lines changed: 0 additions & 3 deletions
@@ -7,7 +7,6 @@
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
 
-namespace onnxruntime {
 namespace cuda {
 
 // Cast
@@ -50,5 +49,3 @@ void Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, si
 }
 
 } // namespace cuda
-
-} // namespace onnxruntime

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 309 additions & 488 deletions
Large diffs are not rendered by default.

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 69 additions & 34 deletions
@@ -4,6 +4,7 @@
 #include "onnxruntime_cxx_api.h"
 #undef ORT_API_MANUAL_INIT
 
+#include "tensorrt_provider_factory.h"
 #include "utils/provider_options.h"
 #include "tensorrt_execution_provider_info.h"
 #include "nv_includes.h"
@@ -150,11 +151,6 @@ class OutputAllocator : public nvinfer1::IOutputAllocator {
   std::vector<int64_t> output_shapes;
 };
 
-using ShapeRangesMap = std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>;
-
-template <typename T>
-using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
-
 struct TensorrtComputeState {
   std::string fused_node_name;
   nvinfer1::IBuilder* builder;
@@ -166,6 +162,8 @@ struct TensorrtComputeState {
   std::vector<std::unordered_map<std::string, size_t>> output_info;
   std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> input_shape_ranges;
   std::mutex* tensorrt_mu_ptr = nullptr;
+  std::string compute_capability;
+  size_t max_workspace_size = 1 << 30;  // 1GB;
   bool fp16_enable = false;
   bool int8_enable = false;
   bool int8_calibration_cache_available = false;
@@ -178,7 +176,7 @@ struct TensorrtComputeState {
   std::vector<nvinfer1::IOptimizationProfile*> profiles;
   bool context_memory_sharing_enable = false;
   size_t* max_context_mem_size_ptr = nullptr;
-  IAllocatorUniquePtr<void>* context_memory = nullptr;
+  AllocatorUniquePtr<void>* context_memory = nullptr;
   std::unordered_map<std::string, float> dynamic_range_map;
   bool engine_decryption_enable = false;
   int (*engine_decryption)(const char*, char*, size_t*) = nullptr;
@@ -193,10 +191,17 @@ struct TensorrtComputeState {
   int auxiliary_streams = -1;
   bool filter_tactic_sources = false;
   nvinfer1::TacticSources tactic_sources;
-  bool cuda_graph_enable = 0;
+  bool cuda_graph_enable = false;
+  bool weight_stripped_engine_enable = false;
+  bool weight_stripped_engine_refit = false;
+  char* model_path;
+  std::string onnx_model_folder_path;
+  const void* onnx_model_bytestream;
+  size_t onnx_model_bytestream_size;
   std::string cache_prefix;
   std::string cache_suffix;
   bool engine_hw_compatible = false;
+  bool sync_stream_after_enqueue = true;
 };
 
 // Minimum information to construct kernel function state for direct engine load code path
@@ -211,6 +216,7 @@ struct TensorrtComputeStateForEPContext {
   std::mutex* tensorrt_mu_ptr = nullptr;
 };
 
+using ShapeRangesMap = std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>;
 using DDSOutputAllocatorMap = std::unordered_map<std::string, std::unique_ptr<OutputAllocator>>;
 std::string GetWeightRefittedEnginePath(std::string engine_cache_path);
 
@@ -220,54 +226,51 @@ static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename";
 
 /// <summary>
 ///
-/// Plugin TensorRT EP OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph.
-///
-/// </summary>
-struct TRTEpNodeComputeInfo : OrtNodeComputeInfo {
-  explicit TRTEpNodeComputeInfo(TensorrtExecutionProvider& ep);
-
-  static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context,
-                                                 void** compute_state);
-  static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state,
-                                             OrtKernelContext* kernel_context);
-  static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state);
-
-  TensorrtExecutionProvider& ep;
-};
-
-/// <summary>
-///
-/// Plugin TensorRT EP that implements OrtEp
+/// Plugin TensorRT EP
 ///
 /// </summary>
-struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
-  TensorrtExecutionProvider(ApiPtrs apis, const std::string& name, const OrtHardwareDevice& device,
-                            const OrtSessionOptions& session_options, const OrtLogger& logger);
+struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs {
+  TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory, const std::string& name,
+                            const OrtHardwareDevice& device, const OrtSessionOptions& session_options,
+                            const OrtLogger& logger);
   ~TensorrtExecutionProvider();
 
+  TensorrtExecutionProviderFactory& factory_;
   std::string name_;
   const OrtHardwareDevice& hardware_device_;
   const OrtSessionOptions& session_options_;
   const OrtLogger& logger_;
 
-  SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations,
-                                        const OrtGraph* graph, bool* early_termination) const;
+  SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations,
+                                        const int max_iterations, const OrtGraph* graph, bool* early_termination) const;
 
   OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, const OrtGraph* graph,
                                                         const OrtNode* fused_node,
                                                         std::unordered_map<std::string, size_t>& input_map,
                                                         std::unordered_map<std::string, size_t>& output_map,
-                                                        OrtNodeComputeInfo* node_compute_info);
+                                                        OrtNodeComputeInfo** node_compute_info);
 
   OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, const OrtGraph* graph, const OrtNode* fused_node,
                                             std::unordered_map<std::string, size_t>& input_map,
                                             std::unordered_map<std::string, size_t>& output_map,
-                                            OrtNodeComputeInfo* node_compute_info);
+                                            OrtNodeComputeInfo** node_compute_info);
+
+  OrtStatus* RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path,
+                         std::string& weight_stripped_engine_cath_path, bool path_check,
+                         const void* onnx_model_bytestream, size_t onnx_model_bytestream_size,
+                         nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine,
+                         bool detailed_build_log);
 
   std::unordered_map<std::string, std::unique_ptr<TensorrtComputeState>>& GetComputeStates() { return compute_states_; }
 
-  std::unordered_map<std::string, std::unique_ptr<TensorrtComputeState>>& GetComputeStatesForEPContext() { return compute_states_; }
+  std::unordered_map<std::string, std::unique_ptr<TensorrtComputeState>>& GetComputeStatesForEPContext() {
+    return compute_states_;
+  }
+
+  void GetAllocator(OrtAllocator** alloc) const { *alloc = alloc_; }
 
+  void SetAllocator(OrtAllocator* alloc) { alloc_ = alloc; }
+
   std::unordered_map<std::string, DDSOutputAllocatorMap>& GetDDSOutputAllocators() {
     return dds_output_allocator_maps_;
   }
@@ -312,6 +315,19 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
   std::unordered_map<std::string, std::string> cache_suffix_;
 
  private:
+  static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept;
+  static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph,
+                                                   OrtEpGraphSupportInfo* graph_support_info);
+  static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs,
+                                             _In_ const OrtNode** fused_nodes, _In_ size_t count,
+                                             _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos,
+                                             _Out_writes_(count) OrtNode** ep_context_nodes);
+  static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos,
+                                                       size_t num_node_compute_infos);
+
+  OrtStatus* CreateEpContextNodes(gsl::span<const OrtNode*> fused_nodes,
+                                  /*out*/ gsl::span<OrtNode*> ep_context_nodes);
+
   mutable TensorrtExecutionProviderInfo info_;
   bool external_stream_ = false;
   cudaStream_t stream_ = nullptr;
@@ -331,6 +347,8 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
   bool weight_stripped_engine_enable_ = false;
   bool weight_stripped_engine_refit_ = false;
   std::string onnx_model_folder_path_;
+  const void* onnx_model_bytestream_;
+  size_t onnx_model_bytestream_size_;
   bool build_heuristics_enable_ = false;
   bool sparsity_enable_ = false;
   int builder_optimization_level_ = 3;
@@ -344,7 +362,7 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
   bool context_memory_sharing_enable_ = false;
   bool layer_norm_fp32_fallback_ = false;
   size_t max_ctx_mem_size_ = 0;
-  IAllocatorUniquePtr<void> context_memory_ = nullptr;
+  AllocatorUniquePtr<void> context_memory_ = nullptr;
   mutable char model_path_[4096] = {};  // Reserved for max path length
   bool engine_decryption_enable_ = false;
   int (*engine_decryption_)(const char*, char*, size_t*) = nullptr;
@@ -419,3 +437,20 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
 
   nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const;
 };
+
+/// <summary>
+///
+/// Plugin TensorRT EP OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph.
+///
+/// </summary>
+struct TRTEpNodeComputeInfo : OrtNodeComputeInfo {
+  explicit TRTEpNodeComputeInfo(TensorrtExecutionProvider& ep);
+
+  static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context,
+                                                 void** compute_state);
+  static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state,
+                                             OrtKernelContext* kernel_context);
+  static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state);
+
+  TensorrtExecutionProvider& ep;
+};
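
Note on the Compile path in this header: the per-graph helpers now return their result through an OrtNodeComputeInfo** out-parameter, and CompileImpl receives a caller-provided node_compute_infos array of size count. A minimal sketch of how these might fit together is shown below; graphs, fused_nodes, input_maps, and output_maps are illustrative stand-ins, not code from this commit.

// Hypothetical sketch only (assumes the declarations from the header above).
// One OrtNodeComputeInfo* is produced per fused node and written into the caller-provided array;
// ownership passes to ORT and is later released through ReleaseNodeComputeInfosImpl.
for (size_t i = 0; i < count; ++i) {
  OrtNodeComputeInfo* compute_info = nullptr;  // filled via the new double-pointer out-parameter
  OrtStatus* status = ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[i], fused_nodes[i],
                                                         input_maps[i], output_maps[i], &compute_info);
  if (status != nullptr) {
    return status;  // propagate the failure to ORT
  }
  node_compute_infos[i] = compute_info;
}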

plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h

Lines changed: 40 additions & 3 deletions
@@ -1,3 +1,5 @@
+#pragma once
+
 #define ORT_API_MANUAL_INIT
 #include "onnxruntime_cxx_api.h"
 #undef ORT_API_MANUAL_INIT
@@ -9,6 +11,7 @@
 // #include "core/framework/murmurhash3.h"
 
 #include"nv_includes.h"
+#include "gsl/narrow"
 
 #include <fstream>
 #include <unordered_map>
@@ -41,6 +44,42 @@ struct ApiPtrs {
 
 namespace fs = std::filesystem;
 
+template <typename T>
+using AllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
+
+bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept {
+  size_t alloc_size = size;
+  if (alignment == 0) {
+    *out = alloc_size * nmemb;
+  } else {
+    size_t alignment_mask = alignment - 1;
+    *out = (alloc_size * nmemb + alignment_mask) & ~static_cast<size_t>(alignment_mask);
+  }
+  return true;
+}
+
+template <typename T>
+AllocatorUniquePtr<T> MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes,
+                                                    bool use_reserve = false) {
+  size_t alloc_size = count_or_bytes;
+  // if T is not void, 'count_or_bytes' == number of items so allow for that
+  if constexpr (!std::is_void<T>::value) {
+    // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't
+    // reachable if T is void. use std::conditional to 'use' void* in the sizeof call
+    constexpr auto size = sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type);
+    CalcMemSizeForArrayWithAlignment(count_or_bytes, size, 0, &alloc_size);
+  }
+
+  T* p = nullptr;
+  if (use_reserve) {
+    p = static_cast<T*>(ort_allocator->Reserve(ort_allocator, alloc_size));
+  } else {
+    p = static_cast<T*>(ort_allocator->Alloc(ort_allocator, alloc_size));
+  }
+
+  return AllocatorUniquePtr<T>{p, [ort_allocator](T* p) { ort_allocator->Free(ort_allocator, p); }};
+}
+
 // Check if cycle exists in the graph after partitioning
 /*
 bool FindCycleHelper(size_t i, gsl::span<const InlinedVector<size_t>> adjacency_map, gsl::span<bool> visited,
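
The new MakeUniquePtrFromOrtAllocator helper above pairs an allocation from an OrtAllocator with a deleter that frees through the same allocator, so buffers can be held in an AllocatorUniquePtr<T> with normal RAII. A brief usage sketch under stated assumptions: the allocator argument stands in for whatever OrtAllocator the EP holds (for example the one passed to SetAllocator in the header diff above), and ExampleAllocations is a made-up illustration, not code from this commit.

// Hypothetical usage sketch (not part of the commit).
void ExampleAllocations(OrtAllocator* allocator, size_t context_mem_bytes) {
  // T != void: count_or_bytes is an element count, so this allocates 256 floats.
  AllocatorUniquePtr<float> scratch = MakeUniquePtrFromOrtAllocator<float>(allocator, 256);

  // T == void: count_or_bytes is a raw byte count, e.g. for shared TensorRT context memory.
  AllocatorUniquePtr<void> context_memory = MakeUniquePtrFromOrtAllocator<void>(allocator, context_mem_bytes);

  // Both blocks are handed back to the OrtAllocator automatically when the unique_ptrs are destroyed.
}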
@@ -168,7 +207,6 @@ std::vector<std::string> SplitToStringVec(std::string const& s, char separator)
   return splitted;
 }
 
-/*
 nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) {
   nvinfer1::TacticSources disabledTactics = 0;
   nvinfer1::TacticSources enabledTactics = 0;
@@ -184,7 +222,7 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) {
 
   const auto toUpper = [](std::string& sourceName) {
     std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(),
-                   [](char c) { return onnxruntime::narrow<char>(std::toupper(c)); });
+                   [](char c) { return gsl::narrow<char>(std::toupper(c)); });
     return sourceName;
   };
 
@@ -223,7 +261,6 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) {
   }
   return enabledTactics & ~disabledTactics;
 }
-*/
 
 inline std::vector<char> loadTimingCacheFile(const std::string inFileName) {
   std::ifstream iFile(inFileName, std::ios::in | std::ios::binary);

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 6 additions & 2 deletions
@@ -206,6 +206,10 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl
   return nullptr;
 }
 
+OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultMemInfo() const {
+  return default_gpu_memory_info_.get();
+}
+
 // To make symbols visible on macOS/iOS
 #ifdef __APPLE__
 #define EXPORT_SYMBOL __attribute__((visibility("default")))
@@ -221,10 +225,10 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const
                                            OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
   const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);
   const OrtEpApi* ort_ep_api = ort_api->GetEpApi();
+  const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi();
 
   // Factory could use registration_name or define its own EP name.
-  std::unique_ptr<OrtEpFactory> factory = std::make_unique<TensorrtExecutionProviderFactory>(registration_name,
-                                                                                             ApiPtrs{*ort_api, *ort_ep_api});
+  std::unique_ptr<OrtEpFactory> factory = std::make_unique<TensorrtExecutionProviderFactory>(registration_name, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api});
 
   if (max_factories < 1) {
     return ort_api->CreateStatus(ORT_INVALID_ARGUMENT,
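
CreateEpFactories now also fetches the OrtModelEditorApi and brace-initializes ApiPtrs with three dereferenced API pointers. The ApiPtrs definition itself is not shown in this commit; a plausible shape, stated purely as an assumption, is sketched below.

// Assumed shape of ApiPtrs (declared in tensorrt_execution_provider_utils.h, not shown in this diff);
// the third member is implied by ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api} above.
struct ApiPtrs {
  const OrtApi& ort_api;
  const OrtEpApi& ep_api;
  const OrtModelEditorApi& model_editor_api;
};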

plugin_execution_providers/tensorrt/tensorrt_provider_factory.h

Lines changed: 2 additions & 1 deletion
@@ -4,9 +4,10 @@
 ///
 /// Plugin TensorRT EP factory that can create an OrtEp and return information about the supported hardware devices.
 ///
-struct TensorrtExecutionProviderFactory : OrtEpFactory, ApiPtrs {
+struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
  public:
   TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis);
+  OrtMemoryInfo* GetDefaultMemInfo() const;
 
  private:
   static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept;
