Skip to content

Commit 75240a4

Browse files
committed
Convert onnxruntime::Status to OrtStatus
1 parent 081de36 commit 75240a4

File tree

8 files changed

+118
-74
lines changed

8 files changed

+118
-74
lines changed

plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,29 +59,29 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
5959
TensorrtExecutionProviderInfo info{};
6060

6161
void* user_compute_stream = nullptr;
62-
ORT_THROW_IF_ERROR(
62+
THROW_IF_ERROR(
6363
ProviderOptionsParser{}
6464
.AddValueParser(
6565
tensorrt::provider_option_names::kDeviceId,
66-
[&info](const std::string& value_str) -> Status {
67-
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id));
66+
[&info](const std::string& value_str) -> OrtStatus* {
67+
RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id));
6868
int num_devices{};
6969
CUDA_RETURN_IF_ERROR(cudaGetDeviceCount(&num_devices));
70-
ORT_RETURN_IF_NOT(
70+
RETURN_IF_NOT(
7171
0 <= info.device_id && info.device_id < num_devices,
7272
"Invalid device ID: ", info.device_id,
7373
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
74-
return Status::OK();
74+
return nullptr;
7575
})
7676
.AddAssignmentToReference(tensorrt::provider_option_names::kMaxPartitionIterations, info.max_partition_iterations)
7777
.AddAssignmentToReference(tensorrt::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
7878
.AddValueParser(
7979
tensorrt::provider_option_names::kUserComputeStream,
80-
[&user_compute_stream](const std::string& value_str) -> Status {
80+
[&user_compute_stream](const std::string& value_str) -> OrtStatus* {
8181
size_t address;
82-
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
82+
RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
8383
user_compute_stream = reinterpret_cast<void*>(address);
84-
return Status::OK();
84+
return nullptr;
8585
})
8686
.AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size)
8787
.AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size)

plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
#pragma once
55

6-
#include <string>
6+
#include "tensorrt_execution_provider_utils.h"
77
#include "provider_options.h"
8-
#include "common.h"
8+
9+
#include <string>
910

1011
#define TRT_DEFAULT_OPTIMIZER_LEVEL 3
1112

@@ -54,7 +55,7 @@ struct TensorrtExecutionProviderInfo {
5455
std::string engine_cache_prefix{""};
5556
bool engine_hw_compatible{false};
5657

57-
static TensorrtExecutionProviderInfo FromProviderOptions(const onnxruntime::ProviderOptions& options);
58+
static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
5859
// static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
5960
// static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info);
6061
// static void UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy);

plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "flatbuffers/idl.h"
88
#include "ort_trt_int8_cal_table.fbs.h"
9+
#include "make_string.h"
910
// #include "core/providers/cuda/cuda_pch.h"
1011
// #include "core/common/path_string.h"
1112
// #include "core/framework/murmurhash3.h"
@@ -21,6 +22,26 @@
2122
#include <iostream>
2223
#include <filesystem>
2324

25+
struct ApiPtrs {
26+
const OrtApi& ort_api;
27+
const OrtEpApi& ep_api;
28+
const OrtModelEditorApi& model_editor_api;
29+
};
30+
31+
const OrtApi* g_ort_api = nullptr;
32+
const OrtEpApi* g_ep_api = nullptr;
33+
const OrtModelEditorApi* g_model_editor_api = nullptr;
34+
35+
#define ENFORCE(condition, ...) \
36+
do { \
37+
if (!(condition)) { \
38+
throw std::runtime_error(MakeString(__VA_ARGS__)); \
39+
} \
40+
} while (false)
41+
42+
#define THROW(...) \
43+
throw std::runtime_error(MakeString(__VA_ARGS__));
44+
2445
#define RETURN_IF_ERROR(fn) \
2546
do { \
2647
OrtStatus* _status = (fn); \
@@ -29,17 +50,60 @@
2950
} \
3051
} while (0)
3152

32-
#define RETURN_IF(cond, ort_api, msg) \
33-
do { \
34-
if ((cond)) { \
35-
return (ort_api).CreateStatus(ORT_EP_FAIL, (msg)); \
36-
} \
53+
/*
54+
template <typename... Args>
55+
std::string ComposeString(Args&&... args) {
56+
std::ostringstream oss;
57+
(oss << ... << args);
58+
return oss.str();
59+
};
60+
*/
61+
62+
#define RETURN_IF(cond, ...) \
63+
do { \
64+
if ((cond)) { \
65+
return Ort::GetApi().CreateStatus(ORT_EP_FAIL, MakeString(__VA_ARGS__).c_str()); \
66+
} \
67+
} while (0)
68+
69+
#define RETURN_IF_NOT(condition, ...) RETURN_IF(!(condition), __VA_ARGS__)
70+
71+
#define MAKE_STATUS(error_code, msg) \
72+
Ort::GetApi().CreateStatus(error_code, (msg));
73+
74+
#define THROW_IF_ERROR(expr) \
75+
do { \
76+
auto _status = (expr); \
77+
if (_status != nullptr) { \
78+
std::ostringstream oss; \
79+
oss << Ort::GetApi().GetErrorMessage(_status); \
80+
Ort::GetApi().ReleaseStatus(_status); \
81+
throw std::runtime_error(oss.str()); \
82+
} \
3783
} while (0)
3884

39-
struct ApiPtrs {
40-
const OrtApi& ort_api;
41-
const OrtEpApi& ep_api;
42-
const OrtModelEditorApi& model_editor_api;
85+
// Helper to release Ort one or more objects obtained from the public C API at the end of their scope.
86+
template <typename T>
87+
struct DeferOrtRelease {
88+
DeferOrtRelease(T** object_ptr, std::function<void(T*)> release_func)
89+
: objects_(object_ptr), count_(1), release_func_(release_func) {}
90+
91+
DeferOrtRelease(T** objects, size_t count, std::function<void(T*)> release_func)
92+
: objects_(objects), count_(count), release_func_(release_func) {}
93+
94+
~DeferOrtRelease() {
95+
if (objects_ != nullptr && count_ > 0) {
96+
for (size_t i = 0; i < count_; ++i) {
97+
if (objects_[i] != nullptr) {
98+
release_func_(objects_[i]);
99+
objects_[i] = nullptr;
100+
}
101+
}
102+
}
103+
}
104+
T** objects_ = nullptr;
105+
size_t count_ = 0;
106+
std::function<void(T*)> release_func_ = nullptr;
43107
};
44108

45109
namespace fs = std::filesystem;

plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,18 @@
22
// Licensed under the MIT License.
33

44
#pragma once
5-
#include "../common.h"
6-
7-
namespace onnxruntime {
85

96
// -----------------------------------------------------------------------
107
// Error handling
118
// -----------------------------------------------------------------------
129
//
1310
template <typename ERRTYPE>
1411
const char* CudaErrString(ERRTYPE) {
15-
ORT_NOT_IMPLEMENTED();
12+
THROW();
1613
}
1714

1815
template <typename ERRTYPE, bool THRW>
19-
std::conditional_t<THRW, void, Status> CudaCall(
16+
std::conditional_t<THRW, void, OrtStatus*> CudaCall(
2017
ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) {
2118
if (retCode != successCode) {
2219
try {
@@ -41,22 +38,20 @@ std::conditional_t<THRW, void, Status> CudaCall(
4138
file, line, exprString, msg);
4239
if constexpr (THRW) {
4340
// throw an exception with the error info
44-
ORT_THROW(str);
41+
THROW(str);
4542
} else {
46-
//LOGS_DEFAULT(ERROR) << str;
47-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str);
43+
return MAKE_STATUS(ORT_EP_FAIL, str);
4844
}
4945
} catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, so we'd never get to see the error
5046
if constexpr (THRW) {
51-
ORT_THROW(e.what());
47+
THROW(e.what());
5248
} else {
53-
//LOGS_DEFAULT(ERROR) << e.what();
54-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what());
49+
return MAKE_STATUS(ORT_EP_FAIL, e.what());
5550
}
5651
}
5752
}
5853
if constexpr (!THRW) {
59-
return Status::OK();
54+
return nullptr;
6055
}
6156
}
6257

@@ -65,5 +60,3 @@ std::conditional_t<THRW, void, Status> CudaCall(
6560
//ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line);
6661

6762
#define CUDA_CALL(expr) (CudaCall<cudaError, false>((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__))
68-
69-
} // namespace onnxruntime

plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55

66
#include "cuda_call.h"
77

8-
namespace onnxruntime {
98
namespace cuda {
109

11-
#define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr))
10+
#define CUDA_RETURN_IF_ERROR(expr) RETURN_IF_ERROR(CUDA_CALL(expr))
1211

13-
} // namespace cuda
14-
} // namespace onnxruntime
12+
} // namespace cuda

plugin_execution_providers/tensorrt/utils/make_string.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
#include <sstream>
2222
#include <type_traits>
2323

24-
namespace onnxruntime {
25-
2624
namespace detail {
2725

2826
inline void MakeStringImpl(std::ostringstream& /*ss*/) noexcept {
@@ -122,5 +120,3 @@ inline std::string MakeStringWithClassicLocale(const std::string& str) {
122120
inline std::string MakeStringWithClassicLocale(const char* cstr) {
123121
return cstr;
124122
}
125-
126-
} // namespace onnxruntime

plugin_execution_providers/tensorrt/utils/parse_string.h

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@
88
#include <string_view>
99
#include <type_traits>
1010

11-
#include "common.h"
12-
13-
namespace onnxruntime {
14-
1511
/**
1612
* Tries to parse a value from an entire string.
1713
*/
@@ -67,9 +63,9 @@ inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
6763
* Parses a value from an entire string.
6864
*/
6965
template <typename T>
70-
Status ParseStringWithClassicLocale(std::string_view s, T& value) {
71-
ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\"");
72-
return Status::OK();
66+
OrtStatus* ParseStringWithClassicLocale(std::string_view s, T& value) {
67+
RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\"");
68+
return nullptr;
7369
}
7470

7571
/**
@@ -81,5 +77,3 @@ T ParseStringWithClassicLocale(std::string_view s) {
8177
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value));
8278
return value;
8379
}
84-
85-
} // namespace onnxruntime

0 commit comments

Comments
 (0)