Commit 731ed72

fix a bunch of compile errors

1 parent 938a3fe commit 731ed72

7 files changed: +45 -62 lines

plugin_execution_providers/tensorrt/cuda_allocator.h

Lines changed: 4 additions & 2 deletions

@@ -13,12 +13,14 @@ constexpr const char* CUDA_PINNED_ALLOCATOR = "CudaPinned";
 using DeviceId = int16_t;
 
 struct CUDAAllocator : OrtAllocator {
-  CUDAAllocator(DeviceId device_id, const char* name = CUDA_ALLOCATOR) {
+  CUDAAllocator(const OrtMemoryInfo* mem_info, const char* name = CUDA_ALLOCATOR) {
     OrtAllocator::version = ORT_API_VERSION;
     OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<CUDAAllocator*>(this_)->Alloc(size); };
     OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<CUDAAllocator*>(this_)->Free(p); };
     OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const CUDAAllocator*>(this_)->Info(); };
 
+    mem_info_ = mem_info;
+
     device_id_ = device_id;
 
     const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
@@ -44,7 +46,7 @@ struct CUDAAllocator : OrtAllocator {
   void SetDevice(bool throw_when_fail) const;
 
   DeviceId device_id_;
-  OrtMemoryInfo* mem_info_ = nullptr;
+  const OrtMemoryInfo* mem_info_ = nullptr;
 };
 
 struct CUDAPinnedAllocator : OrtAllocator {
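
Note: the constructor above uses the usual C-style vtable wiring for a plugin allocator: capture-free lambdas decay to plain function pointers and forward into member functions, and the `OrtMemoryInfo` is now borrowed (const, owned by the factory) rather than created per-allocator. A minimal self-contained sketch of the same pattern, with a hypothetical `MyAllocator` and a malloc-backed `Alloc` standing in for the real CUDA calls:

#include <cstdlib>
#include "onnxruntime_c_api.h"

struct MyAllocator : OrtAllocator {
  explicit MyAllocator(const OrtMemoryInfo* mem_info) : mem_info_(mem_info) {
    OrtAllocator::version = ORT_API_VERSION;
    // Capture-free lambdas decay to the C function pointers OrtAllocator expects.
    OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) {
      return static_cast<MyAllocator*>(this_)->Alloc(size);
    };
    OrtAllocator::Free = [](OrtAllocator* this_, void* p) {
      static_cast<MyAllocator*>(this_)->Free(p);
    };
    OrtAllocator::Info = [](const OrtAllocator* this_) {
      return static_cast<const MyAllocator*>(this_)->Info();
    };
  }

  void* Alloc(size_t size) { return std::malloc(size); }  // cudaMalloc in the real EP
  void Free(void* p) { std::free(p); }
  const OrtMemoryInfo* Info() const { return mem_info_; }

  const OrtMemoryInfo* mem_info_ = nullptr;  // not owned, hence const
};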

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 17 additions & 40 deletions

@@ -1156,7 +1156,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
     weight_stripped_engine_refit_ = true;
   }
 
-  std::unique_ptr<nvinfer1::IHostMemory> serialized_engine = nullptr;
+  std::unique_ptr<nvinfer1::IHostMemory> serialized_engine;
 
   if (!has_dynamic_shape) {
     std::string timing_cache_path = "";
@@ -1258,7 +1258,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
   }
 
   serialized_engine =
-      std::make_unique<nvinfer1::IHostMemory>(trt_builder->buildSerializedNetwork(*trt_network, *trt_config));
+      std::unique_ptr<nvinfer1::IHostMemory>(trt_builder->buildSerializedNetwork(*trt_network, *trt_config));
 
   if (serialized_engine == nullptr) {
     std::string err_msg = "TensorRT EP failed to create engine from network for fused node: " + fused_node_name;
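
This hunk is a genuine compile fix: `nvinfer1::IHostMemory` is an abstract interface, and `std::make_unique<T>(args...)` constructs a brand-new `T` from its arguments, so it can neither instantiate the abstract class nor adopt an existing pointer. `buildSerializedNetwork` already returns an owning raw pointer that only needs wrapping. Reduced sketch of the distinction, using the same names as the hunk:

// std::make_unique<nvinfer1::IHostMemory>(raw) does not compile: it would try to
// construct a new (abstract) IHostMemory from the pointer. The unique_ptr
// constructor simply adopts ownership of the pointer TensorRT returns:
std::unique_ptr<nvinfer1::IHostMemory> serialized_engine{
    trt_builder->buildSerializedNetwork(*trt_network, *trt_config)};
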
@@ -1390,32 +1390,9 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
   input_shape_ranges_[fused_node_name] = input_implicit_shape_ranges;
   profiles_.emplace(fused_node_name, std::move(trt_profiles));
 
-  /*
-  // For dynamic shape input model, firstly TRT EP creates a model proto which includes inputs, outputs and empty
-  // engine. TRT EP will serialize the model at inference time due to engine can be updated and the updated engine
-  // should be included in the model. However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize
-  // it here.
-  if (dump_ep_context_model_ && has_dynamic_shape) {
-    // "ep_cache_context" node attribute should be a relative path to context model directory
-    if (ep_cache_context_attr_.empty()) {
-      auto cache_file_name = std::filesystem::path(engine_cache_path).filename();
-      ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir)
-                                   .append(cache_file_name.string())
-                                   .string();
-    }
-    std::string compute_capability_hw_compat = compute_capability_;
-    if (engine_cache_enable_ && engine_hw_compatible_) {
-      compute_capability_hw_compat = "80+";
-    }
-    model_proto_.reset(CreateCtxModel(graph_body_viewer, ep_cache_context_attr_, nullptr, 0, ep_context_embed_mode_,
-                                      compute_capability_hw_compat, model_path_, GetLogger()));
-    if (ep_context_embed_mode_ == 0) {
-      DumpCtxModel(model_proto_.get(), ctx_model_path_);
-    }
-  }
-  */
 
-  std::unique_ptr<EPContextNodeHelper> ep_ctx_node_helper = std::make_unique<EPContextNodeHelper>(graph, fused_node);
+  // Create EP Context nodes
+  std::unique_ptr<EPContextNodeHelper> ep_ctx_node_helper = std::make_unique<EPContextNodeHelper>(*ep, graph, fused_node);
   if (dump_ep_context_model_) {
     std::string compute_capability_hw_compat = compute_capability_;
     if (engine_cache_enable_ && engine_hw_compatible_) {
@@ -1490,6 +1467,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
                                          engine_hw_compatible_,
                                          sync_stream_after_enqueue_};
 
+  ep->compute_states_[fused_node_name] = std::move(compute_state);
+
   // Update the OrtNodeComputeInfo associated with the graph.
   auto ep_node_compute_info = std::make_unique<TRTEpNodeComputeInfo>(*ep);
   *node_compute_info = ep_node_compute_info.release();
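
Together with the header change further down, this wires up the per-node state lifecycle: the state is registered in the EP's now-public `compute_states_` map at compile time and looked up by fused-node name in `CreateStateImpl` at run time. A condensed, self-contained sketch of that flow (simplified hypothetical types; the real state is `TensorrtComputeState`):

#include <memory>
#include <string>
#include <unordered_map>

struct State { std::string engine_cache_path; };  // stands in for TensorrtComputeState

struct Ep {
  std::unordered_map<std::string, std::unique_ptr<State>> compute_states_;
};

// Compile time: register the state under the fused node's name.
void Compile(Ep& ep, const std::string& fused_node_name) {
  ep.compute_states_[fused_node_name] = std::make_unique<State>();
}

// Session run: look the state up and hand it back through the out-parameter.
bool CreateState(Ep& ep, const std::string& fused_node_name, void** compute_state) {
  auto it = ep.compute_states_.find(fused_node_name);
  if (it == ep.compute_states_.end()) return false;
  *compute_state = it->second.get();
  return true;
}
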
@@ -1554,10 +1533,10 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
   auto supported_control_flow_op = [&](const OrtNode* node) {
     OrtStatus* status = nullptr;
     size_t num_subgraphs = 0;
-    RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs), ort_api);
+    RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs));
 
     std::vector<const OrtGraph*> node_subgraphs(num_subgraphs);
-    RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr), ort_api);
+    RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr));
 
 
     // Iterate the node's subgraphs
@@ -1566,7 +1545,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
 
     // Get number of subgraph's nodes
     size_t num_subgraph_nodes = 0;
-    RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes), ort_api);
+    RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes));
 
     // TRT EP should consider the empty subgraph is fully supported by TRT.
     if (num_subgraph_nodes == 0) {
@@ -1926,13 +1905,11 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine(
 /// </summary>
 TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory,
                                                      const std::string& name,
-                                                     const OrtHardwareDevice& device,
                                                      const OrtSessionOptions& session_options,
                                                      const OrtLogger& logger)
     : ApiPtrs{static_cast<const ApiPtrs&>(factory)},
       factory_(factory),
       name_{name},
-      hardware_device_{device},
       session_options_{session_options},
       logger_{logger} {
@@ -2176,7 +2153,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
  * Please refer to ParserProfileShapes() for more details)
  *
  */
-  bool status = true;
+  // bool status = true;
   // if (status) {
   //   status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_);
   //   if (!status) {
@@ -2266,14 +2243,14 @@ OrtStatus* TRTEpNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, O
   TensorrtExecutionProvider& ep = node_compute_info->ep;
 
   std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context);
-  auto state_it = ep.GetComputeStates().find(fused_node_name);
-  if (state_it == ep.GetComputeStates().end()) {
+  auto state_it = ep.compute_states_.find(fused_node_name);
+  if (state_it == ep.compute_states_.end()) {
     std::string message = "Unable to TensorRT EP's compute state for fused node with name " + fused_node_name;
     return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str());
   }
 
-  TensorrtComputeState& compute_state = *state_it->second;
-  *compute_state = &compute_state;
+  TensorrtComputeState& trt_ep_compute_state = *state_it->second;
+  *compute_state = &trt_ep_compute_state;
   return nullptr;
 }
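
The rename is the compile fix here: the old local reference reused the name of the `void**` out-parameter, so `*compute_state = &compute_state;` tried to dereference a `TensorrtComputeState&` (which has no `operator*`) instead of assigning through the parameter. Reduced to its essence:

struct State {};

void CreateState(State& stored, void** compute_state) {
  // Old: `State& compute_state = stored;` shadowed the parameter, so the next
  // line dereferenced a State& and failed to compile.
  State& state = stored;  // distinct name keeps the out-parameter visible
  *compute_state = &state;
}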

@@ -2335,7 +2312,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void*
   bool context_update = false;
   std::unordered_set<std::string> input_names;
 
-  std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_maps = ep.GetDDSOutputAllocators();
+  std::unordered_map<std::string, DDSOutputAllocatorMap>& dds_output_allocator_maps = ep.GetDDSOutputAllocators();
   auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name];
 
   // Get default OrtMemoryInfo from factory
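
Binding a reference instead of copying matters beyond compilation: with the copy, any DDS output allocators inserted during this call landed in a temporary map and were discarded when `ComputeImpl` returned. A minimal illustration of the difference:

#include <string>
#include <unordered_map>

std::unordered_map<std::string, int> owned_map;
std::unordered_map<std::string, int>& GetMap() { return owned_map; }

void Compute() {
  auto  copied = GetMap();  // old: mutations hit a temporary, lost on return
  auto& live   = GetMap();  // new: mutations persist in the EP-owned map
  live["fused_node"] = 1;
  (void)copied;
}
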
@@ -2911,7 +2888,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void*
 
 void TRTEpNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) {
   (void)this_ptr;
-  TensorrtComputeState& compute_state = *reinterpret_cast<TensorrtComputeState*>(compute_state);
-  (void)compute_state;
+  TensorrtComputeState& trt_ep_compute_state = *reinterpret_cast<TensorrtComputeState*>(compute_state);
+  (void)trt_ep_compute_state;
   // Do nothing for here.
 }

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 5 additions & 11 deletions

@@ -13,6 +13,7 @@
 #include <string>
 #include <unordered_set>
 #include <mutex>
+#include <gsl/span>
 
 #ifdef _WIN32
 #define EXPORT_API __declspec(dllexport)
@@ -231,16 +232,18 @@ static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename";
 /// </summary>
 struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs {
   TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory, const std::string& name,
-                            const OrtHardwareDevice& device, const OrtSessionOptions& session_options,
+                            const OrtSessionOptions& session_options,
                             const OrtLogger& logger);
   ~TensorrtExecutionProvider();
 
   TensorrtExecutionProviderFactory& factory_;
   std::string name_;
-  const OrtHardwareDevice& hardware_device_;
   const OrtSessionOptions& session_options_;
   const OrtLogger& logger_;
 
+  std::unordered_map<std::string, std::unique_ptr<TensorrtComputeState>> compute_states_;
+  std::unordered_map<std::string, std::unique_ptr<TensorrtComputeStateForEPContext>> compute_states_for_ep_context_;
+
   SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations,
                                         const int max_iterations, const OrtGraph* graph, bool* early_termination) const;
 
@@ -262,12 +265,6 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs {
                          nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine,
                          bool detailed_build_log);
 
-  std::unordered_map<std::string, std::unique_ptr<TensorrtComputeState>>& GetComputeStates() { return compute_states_; }
-
-  std::unordered_map<std::string, std::unique_ptr<TensorrtComputeState>>& GetComputeStatesForEPContext() {
-    return compute_states_;
-  }
-
   void GetAllocator(OrtAllocator** alloc) const { *alloc = alloc_; }
 
   void SetAllocator(OrtAllocator* alloc) { alloc_ = alloc; }
@@ -415,9 +412,6 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs {
   std::unordered_map<std::string, std::vector<nvinfer1::IOptimizationProfile*>> profiles_;
   std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_maps_;
 
-  std::unordered_map<std::string, std::unique_ptr<TensorrtComputeState>> compute_states_;
-  std::unordered_map<std::string, std::unique_ptr<TensorrtComputeStateForEPContext>> compute_states_for_ep_context;
-
   // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture
   // cudnnHandle_t external_cudnn_handle_ = nullptr;
   // cublasHandle_t external_cublas_handle_ = nullptr;

plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h

Lines changed: 9 additions & 0 deletions

@@ -82,6 +82,15 @@ std::string ComposeString(Args&&... args) {
   }                                                                   \
 } while (0)
 
+#define RETURN_FALSE_AND_PRINT_IF_ERROR(fn)                            \
+  do {                                                                 \
+    OrtStatus* status = (fn);                                          \
+    if (status != nullptr) {                                           \
+      std::cerr << Ort::GetApi().GetErrorMessage(status) << std::endl; \
+      return false;                                                    \
+    }                                                                  \
+  } while (0)
+
 // Helper to release Ort one or more objects obtained from the public C API at the end of their scope.
 template <typename T>
 struct DeferOrtRelease {
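
The new macro mirrors the error-printing helper above it but returns `false` instead of a status, which is why the `GetCapabilityImpl` call sites earlier in this commit dropped their second argument. A usage sketch inside a boolean predicate (the surrounding lambda is reduced from `GetCapabilityImpl`; `ort_api` comes from the enclosing scope):

auto supported_control_flow_op = [&](const OrtNode* node) -> bool {
  size_t num_subgraphs = 0;
  // On error: prints the OrtStatus message to stderr and returns false from the lambda.
  RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs));
  return num_subgraphs == 0;
};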

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 7 additions & 8 deletions

@@ -31,7 +31,7 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e
 
   // Default GPU allocator OrtMemoryInfo
   OrtMemoryInfo* mem_info = nullptr;
-  auto* status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU", OrtMemoryInfoDeviceType_GPU,
+  auto* status = ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU,
                                              /*vendor*/ 0x10DE, /* device_id */ 0, OrtDeviceMemoryType_DEFAULT,
                                              /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info);
   assert(status == nullptr);  // should never fail.
@@ -40,7 +40,7 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e
   // CUDA PINNED allocator OrtMemoryInfo
   // HOST_ACCESSIBLE memory should use the non-CPU device type
   mem_info = nullptr;
-  status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU pinned", OrtMemoryInfoDeviceType_GPU,
+  status = ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU,
                                        /*vendor*/ 0x10DE, /* device_id */ 0, OrtDeviceMemoryType_HOST_ACCESSIBLE,
                                        /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info);
   assert(status == nullptr);  // should never fail.
@@ -56,12 +56,12 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e
   data_transfer_impl_.reset();  // but we're CPU only so we return nullptr for the IDataTransfer.
 }
 
-const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) {
+const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept {
   const auto* factory = static_cast<const TensorrtExecutionProviderFactory*>(this_ptr);
   return factory->ep_name_.c_str();
 }
 
-const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVendorImpl(const OrtEpFactory* this_ptr) {
+const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVendorImpl(const OrtEpFactory* this_ptr) noexcept {
   const auto* factory = static_cast<const TensorrtExecutionProviderFactory*>(this_ptr);
   return factory->vendor_.c_str();
 }
@@ -72,7 +72,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
                                                                               size_t num_devices,
                                                                               OrtEpDevice** ep_devices,
                                                                               size_t max_ep_devices,
-                                                                              size_t* p_num_ep_devices) {
+                                                                              size_t* p_num_ep_devices) noexcept {
   size_t& num_ep_devices = *p_num_ep_devices;
   auto* factory = static_cast<TensorrtExecutionProviderFactory*>(this_ptr);
 
@@ -133,8 +133,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateEpImpl(
     _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/,
     _In_ size_t num_devices,
    _In_ const OrtSessionOptions* session_options,
-    _In_ const OrtLogger* logger,
-    _Out_ OrtEp** ep) {
+    _In_ const OrtLogger* logger, _Out_ OrtEp** ep) noexcept {
   auto* factory = static_cast<TensorrtExecutionProviderFactory*>(this_ptr);
   *ep = nullptr;
 
@@ -161,7 +160,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateEpImpl(
   return nullptr;
 }
 
-void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) {
+void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept {
  TensorrtExecutionProvider* trt_ep = static_cast<TensorrtExecutionProvider*>(ep);
   delete trt_ep;
 }
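
The added `noexcept` qualifiers bring these definitions in line with the factory's function-pointer declarations: since C++17, `noexcept` is part of a function's type, so a pointer declared `noexcept` will not accept a potentially-throwing function. A minimal illustration (hypothetical alias name):

using GetNameFn = const char* (*)(const void* factory) noexcept;

const char* GetNameThrowing(const void*) { return "ep"; }
const char* GetNameNoexcept(const void*) noexcept { return "ep"; }

// GetNameFn bad = GetNameThrowing;  // error since C++17: discards noexcept
GetNameFn ok = GetNameNoexcept;      // matches the declared function type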

plugin_execution_providers/tensorrt/tensorrt_provider_factory.h

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
+#pragma once
+
 #include "tensorrt_execution_provider_utils.h"
 #include "tensorrt_execution_provider_data_transfer.h"
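
`#pragma once` gives this header the include guard it was missing; without one, a translation unit that reaches the header through two include paths hits redefinition errors. The classic guard form, with a hypothetical guard macro, is equivalent:

#ifndef TENSORRT_PROVIDER_FACTORY_H_  // hypothetical guard name
#define TENSORRT_PROVIDER_FACTORY_H_

#include "tensorrt_execution_provider_utils.h"
#include "tensorrt_execution_provider_data_transfer.h"
// ... rest of the header ...

#endif  // TENSORRT_PROVIDER_FACTORY_H_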

plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h

Lines changed: 1 addition & 1 deletion

@@ -81,7 +81,7 @@
 #define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_
 
 #include <functional>
-#include "core/session/onnxruntime_cxx_api.h"
+#include "onnxruntime_cxx_api.h"
 #include "onnx/onnx_pb.h"
 
 namespace OrtEpUtils {
