Skip to content

Commit c103394

Browse files
committed
Add try/catch for c++ API that throws Ort::Exception
1 parent 5f17a2b commit c103394

File tree

2 files changed

+186
-151
lines changed

2 files changed

+186
-151
lines changed

plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,11 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
200200

201201
// Get "embed_mode" attribute
202202
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("embed_mode", node_attr));
203-
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_INT);
203+
try {
204+
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_INT);
205+
} catch (const Ort::Exception& e) {
206+
return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
207+
}
204208

205209
int64_t embed_mode = 0;
206210
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue(embed_mode));
@@ -211,7 +215,12 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
211215
if (embed_mode) {
212216
// Get engine from byte stream.
213217
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("ep_cache_context", node_attr));
214-
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING);
218+
try {
219+
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING);
220+
} catch (const Ort::Exception& e) {
221+
return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
222+
}
223+
215224
std::string context_binary;
216225
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(context_binary));
217226

@@ -228,6 +237,11 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
228237

229238
if (weight_stripped_engine_refit_) {
230239
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("onnx_model_filename", node_attr));
240+
try {
241+
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING);
242+
} catch (const Ort::Exception& e) {
243+
return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
244+
}
231245
std::string onnx_model_filename;
232246
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(onnx_model_filename));
233247
std::string placeholder;
@@ -249,6 +263,11 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
249263
} else {
250264
// Get engine from cache file.
251265
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("ep_cache_context", node_attr));
266+
try {
267+
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING);
268+
} catch (const Ort::Exception& e) {
269+
return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
270+
}
252271
std::string cache_path;
253272
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(cache_path));
254273

@@ -317,6 +336,11 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
317336

318337
if (weight_stripped_engine_refit_) {
319338
RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("onnx_model_filename", node_attr));
339+
try {
340+
ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING);
341+
} catch (const Ort::Exception& e) {
342+
return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
343+
}
320344
std::string onnx_model_filename;
321345
RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue<std::string>(onnx_model_filename));
322346
std::string weight_stripped_engine_cache = engine_cache_path.string();

0 commit comments

Comments
 (0)