Skip to content

Commit 0bb1a4d

Browse files
committed
address reviewer's comments
1 parent 5ab50ac commit 0bb1a4d

File tree

4 files changed

+30
-65
lines changed

4 files changed

+30
-65
lines changed

plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,30 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose_log, const OrtLogger& ort_
1515
const OrtApi* ort_api);
1616

1717
bool IsAbsolutePath(const std::string& path_string) {
18-
#ifdef _WIN32
19-
PathString ort_path_string = ToPathString(path_string);
20-
auto path = std::filesystem::path(ort_path_string.c_str());
21-
return path.is_absolute();
22-
#else
23-
if (!path_string.empty() && path_string[0] == '/') {
24-
return true;
18+
if (path_string.empty()) {
19+
return false;
2520
}
26-
return false;
27-
#endif
21+
22+
std::filesystem::path path(path_string);
23+
return path.is_absolute();
2824
}
2925

30-
// Like "../file_path"
3126
bool IsRelativePathToParentPath(const std::string& path_string) {
32-
#ifdef _WIN32
33-
PathString ort_path_string = ToPathString(path_string);
34-
auto path = std::filesystem::path(ort_path_string.c_str());
35-
auto relative_path = path.lexically_normal().make_preferred().wstring();
36-
if (relative_path.find(L"..", 0) != std::string::npos) {
37-
return true;
38-
}
39-
return false;
40-
#else
41-
if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) {
42-
return true;
27+
if (path_string.empty())
28+
return false;
29+
30+
std::filesystem::path path(path_string);
31+
32+
// Normalize things like "a/../b" or "foo//bar/.."
33+
path = path.lexically_normal();
34+
35+
// Check each path component
36+
for (const auto& part : path) {
37+
if (part == "..") {
38+
return true;
39+
}
4340
}
4441
return false;
45-
#endif
4642
}
4743

4844
/*
@@ -168,7 +164,7 @@ bool EPContextNodeReader::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& o
168164
*/
169165
OrtStatus* EPContextNodeReader::ValidateEPCtxNode(const OrtGraph* graph) const {
170166
size_t num_nodes = 0;
171-
THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
167+
RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
172168
RETURN_IF_NOT(num_nodes == 1, "Graph contains more than one node.");
173169

174170
std::vector<const OrtNode*> nodes(num_nodes);
@@ -184,9 +180,7 @@ OrtStatus* EPContextNodeReader::ValidateEPCtxNode(const OrtGraph* graph) const {
184180
}
185181

186182
OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
187-
if (ValidateEPCtxNode(&graph) != nullptr) {
188-
return ort_api.CreateStatus(ORT_EP_FAIL, "It's not a valid EPContext node");
189-
}
183+
RETURN_IF_ERROR(ValidateEPCtxNode(&graph));
190184

191185
size_t num_nodes = 0;
192186
RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes));
@@ -200,22 +194,22 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
200194
Ort::ConstOpAttr node_attr;
201195

202196
// Get "embed_mode" attribute
203-
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("embed_mode", node_attr));
197+
RETURN_IF_ERROR(node.GetAttributeByName("embed_mode", node_attr));
204198
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_INT, "\'embed_mode\' attribute should be integer type.");
205199

206200
int64_t embed_mode = 0;
207-
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue(embed_mode));
201+
RETURN_IF_ERROR(node_attr.GetValue(embed_mode));
208202

209203
// Only make path checks if model not provided as byte buffer
210204
bool make_secure_path_checks = !ort_graph.GetModelPath().empty();
211205

212206
if (embed_mode) {
213207
// Get engine from byte stream.
214-
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("ep_cache_context", node_attr));
208+
RETURN_IF_ERROR(node.GetAttributeByName("ep_cache_context", node_attr));
215209
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'ep_cache_context\' attribute should be string type.");
216210

217211
std::string context_binary;
218-
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(context_binary));
212+
RETURN_IF_ERROR(node_attr.GetValue<std::string>(context_binary));
219213

220214
*(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()),
221215
static_cast<size_t>(context_binary.length())));
@@ -229,10 +223,10 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
229223
}
230224

231225
if (weight_stripped_engine_refit_) {
232-
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("onnx_model_filename", node_attr));
226+
RETURN_IF_ERROR(node.GetAttributeByName("onnx_model_filename", node_attr));
233227
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'onnx_model_filename\' attribute should be string type.");
234228
std::string onnx_model_filename;
235-
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(onnx_model_filename));
229+
RETURN_IF_ERROR(node_attr.GetValue<std::string>(onnx_model_filename));
236230
std::string placeholder;
237231
RETURN_IF_ERROR(ep_.RefitEngine(onnx_model_filename,
238232
onnx_model_folder_path_,
@@ -248,10 +242,10 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
248242
}
249243
} else {
250244
// Get engine from cache file.
251-
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("ep_cache_context", node_attr));
245+
RETURN_IF_ERROR(node.GetAttributeByName("ep_cache_context", node_attr));
252246
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'ep_cache_context\' attribute should be string type.");
253247
std::string cache_path;
254-
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(cache_path));
248+
RETURN_IF_ERROR(node_attr.GetValue<std::string>(cache_path));
255249

256250
// For security purpose, in the case of running context model, TRT EP won't allow
257251
// engine cache path to be the relative path like "../file_path" or the absolute path.
@@ -317,10 +311,10 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
317311
message.c_str(), ORT_FILE, __LINE__, __FUNCTION__));
318312

319313
if (weight_stripped_engine_refit_) {
320-
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("onnx_model_filename", node_attr));
314+
RETURN_IF_ERROR(node.GetAttributeByName("onnx_model_filename", node_attr));
321315
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'onnx_model_filename\' attribute should be string type.");
322316
std::string onnx_model_filename;
323-
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(onnx_model_filename));
317+
RETURN_IF_ERROR(node_attr.GetValue<std::string>(onnx_model_filename));
324318
std::string weight_stripped_engine_cache = engine_cache_path.string();
325319
auto status = ep_.RefitEngine(onnx_model_filename,
326320
onnx_model_folder_path_,

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,7 @@ bool ORT_API_CALL TensorrtExecutionProviderFactory::IsStreamAwareImpl(const OrtE
272272

273273
} // namespace trt_ep
274274

275-
// To make symbols visible on macOS/iOS
276-
#ifdef __APPLE__
277-
#define EXPORT_SYMBOL __attribute__((visibility("default")))
278-
#else
279275
#define EXPORT_SYMBOL
280-
#endif
281276

282277
extern "C" {
283278
//

plugin_execution_providers/tensorrt/utils/ep_utils.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ namespace trt_ep {
3838
#define THROW(...) \
3939
throw std::runtime_error(MakeString(__VA_ARGS__));
4040

41-
#define RETURN_IF_ORTSTATUS_ERROR(fn) RETURN_IF_ERROR(fn)
42-
4341
#define RETURN_IF_ERROR(fn) \
4442
do { \
4543
OrtStatus* _status = (fn); \
@@ -48,14 +46,6 @@ namespace trt_ep {
4846
} \
4947
} while (0)
5048

51-
#define RETURN_IF_ORT_STATUS_ERROR(fn) \
52-
do { \
53-
auto _status = (fn); \
54-
if (!_status.IsOK()) { \
55-
return _status; \
56-
} \
57-
} while (0)
58-
5949
#define RETURN_IF(cond, ...) \
6050
do { \
6151
if ((cond)) { \

plugin_execution_providers/tensorrt/utils/helper.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,3 @@ std::wstring ToWideString(std::string_view s) {
4848
return ret;
4949
}
5050
#endif // #ifdef _WIN32
51-
52-
#ifdef NO_EXCEPTIONS
53-
void PrintFinalMessage(const char* msg) {
54-
#if defined(__ANDROID__)
55-
__android_log_print(ANDROID_LOG_ERROR, "onnxruntime", "%s", msg);
56-
#else
57-
// TODO, consider changing the output of the error message from std::cerr to logging when the
58-
// exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output
59-
// might not be easily accessible on some systems such as mobile
60-
// TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS
61-
std::cerr << msg << std::endl;
62-
#endif
63-
}
64-
#endif // #ifdef NO_EXCEPTIONS

0 commit comments

Comments
 (0)