Skip to content

Commit 09138ee

Browse files
committed
clean up
1 parent c5363e6 commit 09138ee

File tree

7 files changed

+8
-13
lines changed

7 files changed

+8
-13
lines changed

plugin_execution_providers/tensorrt/cuda_allocator.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
#pragma once
55
#include <atomic>
66
#include "onnxruntime_c_api.h"
7-
#define ORT_API_MANUAL_INIT
8-
#include "onnxruntime_cxx_api.h"
97

108
using DeviceId = int16_t;
119

plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ bool EPContextNodeHelper::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& o
2121
RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
2222

2323
std::vector<const OrtNode*> nodes(num_nodes);
24+
RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size()));
2425

2526
for (size_t i = 0; i < num_nodes; ++i) {
2627
auto node = nodes[i];

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
#include <iostream>
66
#include <cuda_runtime.h>
77

8-
#define ORT_API_MANUAL_INIT
98
#include "onnxruntime_cxx_api.h"
10-
#undef ORT_API_MANUAL_INIT
119

1210
#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL
1311
#include "ort_graph_to_proto.h"

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
#pragma once
22

3-
#define ORT_API_MANUAL_INIT
4-
#include "onnxruntime_cxx_api.h"
5-
#undef ORT_API_MANUAL_INIT
6-
73
#include "tensorrt_provider_factory.h"
84
#include "utils/provider_options.h"
95
#include "tensorrt_execution_provider_info.h"

plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
#pragma once
22

3-
#define ORT_API_MANUAL_INIT
43
#include "onnxruntime_cxx_api.h"
5-
#undef ORT_API_MANUAL_INIT
64

75
#include "ep_utils.h"
86
#include "flatbuffers/idl.h"

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
#define ORT_API_MANUAL_INIT
21
#include "onnxruntime_cxx_api.h"
3-
#undef ORT_API_MANUAL_INIT
42
#include "tensorrt_provider_factory.h"
53
#include "tensorrt_execution_provider.h"
64
#include "cuda_allocator.h"
@@ -145,7 +143,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
145143
cuda_gpu_mem_devices, // device memory
146144
cuda_pinned_mem_devices // shared memory
147145
);
148-
146+
factory->SetGPUDataTransfer(std::move(data_transfer_impl));
149147
return nullptr;
150148
}
151149

@@ -283,6 +281,10 @@ void TensorrtExecutionProviderFactory::SetHostAccessibleMemInfo(MemoryInfoUnique
283281
cuda_pinned_memory_infos_.push_back(std::move(mem_info));
284282
}
285283

284+
void TensorrtExecutionProviderFactory::SetGPUDataTransfer(std::unique_ptr<TRTEpDataTransfer> gpu_data_transfer) {
285+
data_transfer_impl_ = std::move(gpu_data_transfer);
286+
}
287+
286288
// To make symbols visible on macOS/iOS
287289
#ifdef __APPLE__
288290
#define EXPORT_SYMBOL __attribute__((visibility("default")))

plugin_execution_providers/tensorrt/tensorrt_provider_factory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
5252

5353
void SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id);
5454

55+
void SetGPUDataTransfer(std::unique_ptr<TRTEpDataTransfer> gpu_data_transfer);
56+
5557
const std::string ep_name_; // EP name
5658
const std::string vendor_{"Nvidia"}; // EP vendor name
5759
const std::string ep_version_{"0.1.0"}; // EP version

0 commit comments

Comments
 (0)