Skip to content

Commit 5ab50ac

Browse files
committed
address reviewer's comments
1 parent c69dd60 commit 5ab50ac

File tree

7 files changed

+31
-68
lines changed

7 files changed

+31
-68
lines changed

plugin_execution_providers/tensorrt/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ endif()
2727

2828
add_definitions(-DONNX_NAMESPACE=onnx)
2929
add_definitions(-DONNX_ML)
30-
add_definitions(-DNV_TENSORRT_MAJOR=10)
3130
add_definitions(-DNOMINMAX)
3231
file(GLOB tensorrt_src "./*.cc" "./utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu" "./*.h")
3332
add_library(TensorRTEp SHARED ${tensorrt_src})

plugin_execution_providers/tensorrt/cuda_allocator.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ struct CUDAAllocator : OrtAllocator {
1919
OrtAllocator::GetStats = nullptr;
2020
OrtAllocator::AllocOnStream = nullptr; // Allocate memory, handling usage across different Streams. Not used for TRT EP.
2121
}
22-
// TODO: Handle destructor
23-
//~CUDAAllocator();
2422

2523
void* Alloc(size_t size);
2624
void Free(void* p);
@@ -48,8 +46,6 @@ struct CUDAPinnedAllocator : OrtAllocator {
4846
OrtAllocator::GetStats = nullptr;
4947
OrtAllocator::AllocOnStream = nullptr;
5048
}
51-
// TODO: Handle destructor
52-
//~CUDAPinnedAllocator();
5349

5450
void* Alloc(size_t size);
5551
void Free(void* p);

plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -166,24 +166,25 @@ bool EPContextNodeReader::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& o
166166
/*
167167
* The sanity check for EP context contrib op.
168168
*/
169-
bool EPContextNodeReader::ValidateEPCtxNode(const OrtGraph* graph) const {
169+
OrtStatus* EPContextNodeReader::ValidateEPCtxNode(const OrtGraph* graph) const {
170170
size_t num_nodes = 0;
171171
THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
172-
ENFORCE(num_nodes == 1);
172+
RETURN_IF_NOT(num_nodes == 1, "Graph contains more than one node.");
173173

174174
std::vector<const OrtNode*> nodes(num_nodes);
175175
RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size()));
176176

177177
const char* op_type = nullptr;
178178
RETURN_IF_ERROR(ort_api.Node_GetOperatorType(nodes[0], &op_type));
179-
ENFORCE(std::string(op_type) == "EPContext");
179+
RETURN_IF_NOT(std::string(op_type) == "EPContext", "Node is not an EPContext node.");
180180

181181
// TODO: Check compute capability and others
182-
return true;
182+
183+
return nullptr;
183184
}
184185

185186
OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
186-
if (!ValidateEPCtxNode(&graph)) {
187+
if (ValidateEPCtxNode(&graph) != nullptr) {
187188
return ort_api.CreateStatus(ORT_EP_FAIL, "It's not a valid EPContext node");
188189
}
189190

@@ -200,11 +201,7 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
200201

201202
// Get "embed_mode" attribute
202203
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("embed_mode", node_attr));
203-
try {
204-
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_INT);
205-
} catch (const Ort::Exception& e) {
206-
return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
207-
}
204+
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_INT, "\'embed_mode\' attribute should be integer type.");
208205

209206
int64_t embed_mode = 0;
210207
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue(embed_mode));
@@ -215,11 +212,7 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
215212
if (embed_mode) {
216213
// Get engine from byte stream.
217214
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("ep_cache_context", node_attr));
218-
try {
219-
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING);
220-
} catch (const Ort::Exception& e) {
221-
return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
222-
}
215+
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'ep_cache_context\' attribute should be string type.");
223216

224217
std::string context_binary;
225218
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(context_binary));
@@ -237,37 +230,26 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
237230

238231
if (weight_stripped_engine_refit_) {
239232
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("onnx_model_filename", node_attr));
240-
try {
241-
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING);
242-
} catch (const Ort::Exception& e) {
243-
return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
244-
}
233+
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'onnx_model_filename\' attribute should be string type.");
245234
std::string onnx_model_filename;
246235
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(onnx_model_filename));
247236
std::string placeholder;
248-
auto status = ep_.RefitEngine(onnx_model_filename,
249-
onnx_model_folder_path_,
250-
placeholder,
251-
make_secure_path_checks,
252-
onnx_model_bytestream_,
253-
onnx_model_bytestream_size_,
254-
onnx_external_data_bytestream_,
255-
onnx_external_data_bytestream_size_,
256-
(*trt_engine_).get(),
257-
false, // serialize refitted engine to disk
258-
detailed_build_log_);
259-
if (status != nullptr) {
260-
return ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed.");
261-
}
237+
RETURN_IF_ERROR(ep_.RefitEngine(onnx_model_filename,
238+
onnx_model_folder_path_,
239+
placeholder,
240+
make_secure_path_checks,
241+
onnx_model_bytestream_,
242+
onnx_model_bytestream_size_,
243+
onnx_external_data_bytestream_,
244+
onnx_external_data_bytestream_size_,
245+
(*trt_engine_).get(),
246+
false, // serialize refitted engine to disk
247+
detailed_build_log_));
262248
}
263249
} else {
264250
// Get engine from cache file.
265251
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("ep_cache_context", node_attr));
266-
try {
267-
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING);
268-
} catch (const Ort::Exception& e) {
269-
return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
270-
}
252+
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'ep_cache_context\' attribute should be string type.");
271253
std::string cache_path;
272254
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(cache_path));
273255

@@ -336,11 +318,7 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
336318

337319
if (weight_stripped_engine_refit_) {
338320
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("onnx_model_filename", node_attr));
339-
try {
340-
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING);
341-
} catch (const Ort::Exception& e) {
342-
return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
343-
}
321+
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'onnx_model_filename\' attribute should be string type.");
344322
std::string onnx_model_filename;
345323
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(onnx_model_filename));
346324
std::string weight_stripped_engine_cache = engine_cache_path.string();

plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class EPContextNodeReader : public ApiPtrs {
7575

7676
static bool GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api);
7777

78-
bool ValidateEPCtxNode(const OrtGraph* graph) const;
78+
OrtStatus* ValidateEPCtxNode(const OrtGraph* graph) const;
7979

8080
OrtStatus* GetEpContextFromGraph(const OrtGraph& graph);
8181

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
964964
const OrtGraph* subgraph_raw_pointer = subgraph;
965965
if (subgraph_raw_pointer != graph) {
966966
size_t num_subgraph_nodes = 0;
967-
THROW_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes));
967+
RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes));
968968

969969
// Another subgraph of "If" control flow op has no nodes.
970970
// In this case, TRT EP should consider this empty subgraph is fully supported by TRT.
@@ -2428,12 +2428,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
24282428
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl;
24292429

24302430
// Initialize the execution provider.
2431-
auto ort_status = ort_api.Logger_LogMessage(&logger_,
2432-
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
2433-
("Plugin EP has been created with name " + name_).c_str(),
2434-
ORT_FILE, __LINE__, __FUNCTION__);
2435-
// ignore status for now
2436-
(void)ort_status;
2431+
2432+
Ort::Status ort_status(ort_api.Logger_LogMessage(&logger_,
2433+
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
2434+
("Plugin EP has been created with name " + name_).c_str(),
2435+
ORT_FILE, __LINE__, __FUNCTION__));
24372436

24382437
// populate apis as global for utility functions
24392438
g_ort_api = &ort_api;

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,8 @@ class TensorrtLogger : public nvinfer1::ILogger {
7474

7575
namespace tensorrt_ptr {
7676

77-
struct TensorrtInferDeleter {
78-
template <typename T>
79-
void operator()(T* obj) const {
80-
if (obj) {
81-
delete obj;
82-
}
83-
}
84-
};
85-
8677
template <typename T>
87-
using unique_pointer = std::unique_ptr<T, TensorrtInferDeleter>;
78+
using unique_pointer = std::unique_ptr<T, std::default_delete<T>>;
8879
}; // namespace tensorrt_ptr
8980

9081
class OutputAllocator : public nvinfer1::IOutputAllocator {

plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1107,4 +1107,4 @@ std::string GetCacheSuffix(const std::string& fused_node_name, const std::string
11071107
return "";
11081108
}
11091109

1110-
}
1110+
} // namespace trt_ep

0 commit comments

Comments
 (0)