Skip to content

Commit a65908f

Browse files
committed
Fix compile errors/issues
1 parent 35b0cf1 commit a65908f

11 files changed

+55
-107
lines changed

plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <fstream>
66
#include <filesystem>
77

8-
#include "tensorrt_execution_provider_utils.h"
8+
#include "ep_utils.h"
99
#include "onnx_ctx_model_helper.h"
1010

1111
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
@@ -109,3 +109,20 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca
109109

110110
return nullptr;
111111
}
112+
113+
/*
114+
* Get the weight-refitted engine cache path from a weight-stripped engine cache path
115+
*
116+
* Weight-stipped engine:
117+
* An engine with weights stripped and its size is smaller than a regualr engine.
118+
* The cache name of weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine
119+
*
120+
* Weight-refitted engine:
121+
* An engine that its weights have been refitted and it's simply a regular engine.
122+
* The cache name of weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine
123+
*/
124+
std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) {
125+
std::filesystem::path stripped_engine_cache_path(stripped_engine_cache);
126+
std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine";
127+
return refitted_engine_cache_path;
128+
}

plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
#pragma once
55

6-
#include "tensorrt_execution_provider_utils.h"
76
#include "tensorrt_execution_provider.h"
7+
#include "ep_utils.h"
88
#include "nv_includes.h"
99

1010
#include <string>

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL
1313
#include "ort_graph_to_proto.h"
1414

15-
//#include "tensorrt_execution_provider_utils.h"
15+
#include "tensorrt_execution_provider_utils.h"
1616
#include "tensorrt_execution_provider.h"
1717
#include "cuda_allocator.h"
1818
#include "onnx_ctx_model_helper.h"
1919
#include "onnx/onnx_pb.h"
2020
#include "cuda/unary_elementwise_ops_impl.h"
21+
#include "ep_utils.h"
2122

2223
#ifdef _WIN32
2324
#include <windows.h>
@@ -31,6 +32,10 @@
3132
#define LIBFUNC(lib, fn) dlsym((lib), (fn))
3233
#endif
3334

35+
const OrtApi* g_ort_api = nullptr;
36+
const OrtEpApi* g_ep_api = nullptr;
37+
const OrtModelEditorApi* g_model_editor_api = nullptr;
38+
3439
void CUDA_RETURN_IF_ERROR(cudaError_t res) {
3540
if (res != cudaSuccess) abort();
3641
}
@@ -1795,9 +1800,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_
17951800

17961801
OrtStatus* status;
17971802
if (EPContextNodeHelper::GraphHasCtxNode(graphs[fused_node_idx], ort_api)) {
1798-
RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node,
1799-
input_map, output_map,
1800-
&node_compute_infos_result[fused_node_idx]));
1803+
//RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node,
1804+
// input_map, output_map,
1805+
// &node_compute_infos_result[fused_node_idx]));
18011806
} else {
18021807
RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[fused_node_idx], fused_node, input_map,
18031808
output_map, &node_compute_infos_result[fused_node_idx],
@@ -1899,6 +1904,8 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine(
18991904
#endif
19001905
}
19011906

1907+
TensorrtExecutionProvider::~TensorrtExecutionProvider() = default;
1908+
19021909
/// <summary>
19031910
///
19041911
/// Plugin TensorRT EP that implements OrtEp
@@ -1908,7 +1915,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
19081915
const std::string& name,
19091916
const OrtSessionOptions& session_options,
19101917
const OrtLogger& logger)
1911-
: ApiPtrs{static_cast<const ApiPtrs&>(factory)},
1918+
: OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized
1919+
ApiPtrs{static_cast<const ApiPtrs&>(factory)},
19121920
factory_(factory),
19131921
name_{name},
19141922
session_options_{session_options},
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
LIBRARY "TensorRTEp.dll"
2+
EXPORTS
3+
CreateEpFactories @1
4+
ReleaseEpFactory @2
5+
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
VERS_1.0.0 {
2+
global:
3+
CreateEpFactories;
4+
ReleaseEpFactory;
5+
local:
6+
*;
7+
};

plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#pragma once
55

6-
#include "tensorrt_execution_provider_utils.h"
6+
#include "ep_utils.h"
77

88
struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs {
99
TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector<const OrtMemoryDevice*> device_mem_infos,

plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "tensorrt_execution_provider_info.h"
77
#include "provider_options_utils.h"
88
#include "cuda/cuda_common.h"
9+
#include "ep_utils.h"
910

1011
namespace tensorrt {
1112
namespace provider_option_names {

plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#pragma once
55

6-
#include "tensorrt_execution_provider_utils.h"
76
#include "provider_options.h"
87

98
#include <string>

plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h

Lines changed: 1 addition & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "onnxruntime_cxx_api.h"
55
#undef ORT_API_MANUAL_INIT
66

7+
#include "ep_utils.h"
78
#include "flatbuffers/idl.h"
89
#include "ort_trt_int8_cal_table.fbs.h"
910
#include "make_string.h"
@@ -22,104 +23,8 @@
2223
#include <iostream>
2324
#include <filesystem>
2425

25-
struct ApiPtrs {
26-
const OrtApi& ort_api;
27-
const OrtEpApi& ep_api;
28-
const OrtModelEditorApi& model_editor_api;
29-
};
30-
31-
const OrtApi* g_ort_api = nullptr;
32-
const OrtEpApi* g_ep_api = nullptr;
33-
const OrtModelEditorApi* g_model_editor_api = nullptr;
34-
35-
#define ENFORCE(condition, ...) \
36-
do { \
37-
if (!(condition)) { \
38-
throw std::runtime_error(MakeString(__VA_ARGS__)); \
39-
} \
40-
} while (false)
41-
42-
#define THROW(...) \
43-
throw std::runtime_error(MakeString(__VA_ARGS__));
44-
45-
#define RETURN_IF_ERROR(fn) \
46-
do { \
47-
OrtStatus* _status = (fn); \
48-
if (_status != nullptr) { \
49-
return _status; \
50-
} \
51-
} while (0)
52-
53-
/*
54-
template <typename... Args>
55-
std::string ComposeString(Args&&... args) {
56-
std::ostringstream oss;
57-
(oss << ... << args);
58-
return oss.str();
59-
};
60-
*/
61-
62-
#define RETURN_IF(cond, ...) \
63-
do { \
64-
if ((cond)) { \
65-
return Ort::GetApi().CreateStatus(ORT_EP_FAIL, MakeString(__VA_ARGS__).c_str()); \
66-
} \
67-
} while (0)
68-
69-
#define RETURN_IF_NOT(condition, ...) RETURN_IF(!(condition), __VA_ARGS__)
70-
71-
#define MAKE_STATUS(error_code, msg) \
72-
Ort::GetApi().CreateStatus(error_code, (msg));
73-
74-
#define THROW_IF_ERROR(expr) \
75-
do { \
76-
auto _status = (expr); \
77-
if (_status != nullptr) { \
78-
std::ostringstream oss; \
79-
oss << Ort::GetApi().GetErrorMessage(_status); \
80-
Ort::GetApi().ReleaseStatus(_status); \
81-
throw std::runtime_error(oss.str()); \
82-
} \
83-
} while (0)
84-
85-
#define RETURN_FALSE_AND_PRINT_IF_ERROR(fn) \
86-
do { \
87-
OrtStatus* status = (fn); \
88-
if (status != nullptr) { \
89-
std::cerr << Ort::GetApi().GetErrorMessage(status) << std::endl; \
90-
return false; \
91-
} \
92-
} while (0)
93-
94-
// Helper to release Ort one or more objects obtained from the public C API at the end of their scope.
95-
template <typename T>
96-
struct DeferOrtRelease {
97-
DeferOrtRelease(T** object_ptr, std::function<void(T*)> release_func)
98-
: objects_(object_ptr), count_(1), release_func_(release_func) {}
99-
100-
DeferOrtRelease(T** objects, size_t count, std::function<void(T*)> release_func)
101-
: objects_(objects), count_(count), release_func_(release_func) {}
102-
103-
~DeferOrtRelease() {
104-
if (objects_ != nullptr && count_ > 0) {
105-
for (size_t i = 0; i < count_; ++i) {
106-
if (objects_[i] != nullptr) {
107-
release_func_(objects_[i]);
108-
objects_[i] = nullptr;
109-
}
110-
}
111-
}
112-
}
113-
T** objects_ = nullptr;
114-
size_t count_ = 0;
115-
std::function<void(T*)> release_func_ = nullptr;
116-
};
117-
11826
namespace fs = std::filesystem;
11927

120-
template <typename T>
121-
using AllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
122-
12328
bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept {
12429
size_t alloc_size = size;
12530
if (alignment == 0) {

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,16 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
5858

5959
std::vector<const OrtMemoryDevice*> cuda_gpu_mem_devices;
6060
std::vector<const OrtMemoryDevice*> cuda_pinned_mem_devices;
61+
int GPU_cnt = 0;
6162

6263
for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
6364
// C API
6465
const OrtHardwareDevice& device = *devices[i];
6566
if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
67+
if (GPU_cnt > 0) {
68+
continue;
69+
}
70+
GPU_cnt++;
6671
// These can be returned as nullptr if you have nothing to add.
6772
OrtKeyValuePairs* ep_metadata = nullptr;
6873
OrtKeyValuePairs* ep_options = nullptr;
@@ -87,7 +92,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
8792
}
8893

8994
uint32_t vendor_id = factory->ort_api.HardwareDevice_VendorId(&device);
90-
uint32_t device_id = factory->ort_api.HardwareDevice_DeviceId(&device);
95+
//uint32_t device_id = factory->ort_api.HardwareDevice_DeviceId(&device);
96+
uint32_t device_id = 0;
9197

9298
// CUDA allocator OrtMemoryInfo
9399
OrtMemoryInfo* mem_info = nullptr;

0 commit comments

Comments
 (0)