@@ -166,24 +166,25 @@ bool EPContextNodeReader::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& o
166166/*
167167 * The sanity check for EP context contrib op.
168168 */
169- bool EPContextNodeReader::ValidateEPCtxNode (const OrtGraph* graph) const {
169+ OrtStatus* EPContextNodeReader::ValidateEPCtxNode (const OrtGraph* graph) const {
170170 size_t num_nodes = 0 ;
171171 THROW_IF_ERROR (ort_api.Graph_GetNumNodes (graph, &num_nodes));
172- ENFORCE (num_nodes == 1 );
172+ RETURN_IF_NOT (num_nodes == 1 , " Graph contains more than one node. " );
173173
174174 std::vector<const OrtNode*> nodes (num_nodes);
175175 RETURN_IF_ERROR (ort_api.Graph_GetNodes (graph, nodes.data (), nodes.size ()));
176176
177177 const char * op_type = nullptr ;
178178 RETURN_IF_ERROR (ort_api.Node_GetOperatorType (nodes[0 ], &op_type));
179- ENFORCE (std::string (op_type) == " EPContext" );
179+ RETURN_IF_NOT (std::string (op_type) == " EPContext" , " Node is not an EPContext node. " );
180180
181181 // TODO: Check compute capability and others
182- return true ;
182+
183+ return nullptr ;
183184}
184185
185186OrtStatus* EPContextNodeReader::GetEpContextFromGraph (const OrtGraph& graph) {
186- if (! ValidateEPCtxNode (&graph)) {
187+ if (ValidateEPCtxNode (&graph) != nullptr ) {
187188 return ort_api.CreateStatus (ORT_EP_FAIL, " It's not a valid EPContext node" );
188189 }
189190
@@ -200,11 +201,7 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
200201
201202 // Get "embed_mode" attribute
202203 RETURN_IF_ORT_STATUS_ERROR (node.GetAttributeByName (" embed_mode" , node_attr));
203- try {
204- ENFORCE (node_attr.GetType () == OrtOpAttrType::ORT_OP_ATTR_INT);
205- } catch (const Ort::Exception& e) {
206- return ort_api.CreateStatus (ORT_EP_FAIL, e.what ());
207- }
204+ RETURN_IF_NOT (node_attr.GetType () == OrtOpAttrType::ORT_OP_ATTR_INT, " \' embed_mode\' attribute should be integer type." );
208205
209206 int64_t embed_mode = 0 ;
210207 RETURN_IF_ORT_STATUS_ERROR (node_attr.GetValue (embed_mode));
@@ -215,11 +212,7 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
215212 if (embed_mode) {
216213 // Get engine from byte stream.
217214 RETURN_IF_ORT_STATUS_ERROR (node.GetAttributeByName (" ep_cache_context" , node_attr));
218- try {
219- ENFORCE (node_attr.GetType () == OrtOpAttrType::ORT_OP_ATTR_STRING);
220- } catch (const Ort::Exception& e) {
221- return ort_api.CreateStatus (ORT_EP_FAIL, e.what ());
222- }
215+ RETURN_IF_NOT (node_attr.GetType () == OrtOpAttrType::ORT_OP_ATTR_STRING, " \' ep_cache_context\' attribute should be string type." );
223216
224217 std::string context_binary;
225218 RETURN_IF_ORT_STATUS_ERROR (node_attr.GetValue <std::string>(context_binary));
@@ -237,37 +230,26 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
237230
238231 if (weight_stripped_engine_refit_) {
239232 RETURN_IF_ORT_STATUS_ERROR (node.GetAttributeByName (" onnx_model_filename" , node_attr));
240- try {
241- ENFORCE (node_attr.GetType () == OrtOpAttrType::ORT_OP_ATTR_STRING);
242- } catch (const Ort::Exception& e) {
243- return ort_api.CreateStatus (ORT_EP_FAIL, e.what ());
244- }
233+ RETURN_IF_NOT (node_attr.GetType () == OrtOpAttrType::ORT_OP_ATTR_STRING, " \' onnx_model_filename\' attribute should be string type." );
245234 std::string onnx_model_filename;
246235 RETURN_IF_ORT_STATUS_ERROR (node_attr.GetValue <std::string>(onnx_model_filename));
247236 std::string placeholder;
248- auto status = ep_.RefitEngine (onnx_model_filename,
249- onnx_model_folder_path_,
250- placeholder,
251- make_secure_path_checks,
252- onnx_model_bytestream_,
253- onnx_model_bytestream_size_,
254- onnx_external_data_bytestream_,
255- onnx_external_data_bytestream_size_,
256- (*trt_engine_).get (),
257- false , // serialize refitted engine to disk
258- detailed_build_log_);
259- if (status != nullptr ) {
260- return ort_api.CreateStatus (ORT_EP_FAIL, " RefitEngine failed." );
261- }
237+ RETURN_IF_ERROR (ep_.RefitEngine (onnx_model_filename,
238+ onnx_model_folder_path_,
239+ placeholder,
240+ make_secure_path_checks,
241+ onnx_model_bytestream_,
242+ onnx_model_bytestream_size_,
243+ onnx_external_data_bytestream_,
244+ onnx_external_data_bytestream_size_,
245+ (*trt_engine_).get (),
246+ false , // serialize refitted engine to disk
247+ detailed_build_log_));
262248 }
263249 } else {
264250 // Get engine from cache file.
265251 RETURN_IF_ORT_STATUS_ERROR (node.GetAttributeByName (" ep_cache_context" , node_attr));
266- try {
267- ENFORCE (node_attr.GetType () == OrtOpAttrType::ORT_OP_ATTR_STRING);
268- } catch (const Ort::Exception& e) {
269- return ort_api.CreateStatus (ORT_EP_FAIL, e.what ());
270- }
252+ RETURN_IF_NOT (node_attr.GetType () == OrtOpAttrType::ORT_OP_ATTR_STRING, " \' ep_cache_context\' attribute should be string type." );
271253 std::string cache_path;
272254 RETURN_IF_ORT_STATUS_ERROR (node_attr.GetValue <std::string>(cache_path));
273255
@@ -336,11 +318,7 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
336318
337319 if (weight_stripped_engine_refit_) {
338320 RETURN_IF_ORT_STATUS_ERROR (node.GetAttributeByName (" onnx_model_filename" , node_attr));
339- try {
340- ENFORCE (node_attr.GetType () == OrtOpAttrType::ORT_OP_ATTR_STRING);
341- } catch (const Ort::Exception& e) {
342- return ort_api.CreateStatus (ORT_EP_FAIL, e.what ());
343- }
321+ RETURN_IF_NOT (node_attr.GetType () == OrtOpAttrType::ORT_OP_ATTR_STRING, " \' onnx_model_filename\' attribute should be string type." );
344322 std::string onnx_model_filename;
345323 RETURN_IF_ORT_STATUS_ERROR (node_attr.GetValue <std::string>(onnx_model_filename));
346324 std::string weight_stripped_engine_cache = engine_cache_path.string ();
0 commit comments