@@ -1330,8 +1330,9 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
                                                  OrtEpGraphSupportInfo* graph_support_info) {
   TensorrtExecutionProvider* ep = static_cast<TensorrtExecutionProvider*>(this_ptr);
   const OrtApi& ort_api = ep->ort_api;
-  /*
+
   // Get ModelPath
+  /*
   const std::filesystem::path* model_path = nullptr;
   graph_api_->OrtGraph_GetModelPath(graph, reinterpret_cast<const void**>(&model_path));
   const auto& path_string = model_path->string();
@@ -1387,6 +1388,8 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
   SubGraphCollection_t parser_nodes_vector, supported_nodes_vector;
   bool new_subgraph = true;
 
+  std::unordered_set<std::string> control_flow_op_set = {"If", "Loop", "Scan"};
+
   /* Iterate all the nodes and exclude the node if:
    * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
    * 2. Its op type is in the exclusion list.
@@ -1407,7 +1410,7 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
     const char* op_type = nullptr;
     RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type));
 
-    if (ep->control_flow_op_set_.find(op_type) != ep->control_flow_op_set_.end()) {
+    if (control_flow_op_set.find(op_type) != control_flow_op_set.end()) {
       auto supported_control_flow_op = [&](const OrtNode* node) {
         OrtStatus* status = nullptr;
         size_t num_subgraphs = 0;
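The supported_control_flow_op lambda that begins in this hunk implements criterion 1 from the comment block above: an If/Loop/Scan node is kept only when every one of its subgraphs is fully TRT eligible. A minimal sketch of that decision, assuming the per-subgraph eligibility results are already available (how the EP obtains them is not shown in this hunk):

    #include <algorithm>
    #include <vector>

    // Criterion 1 distilled: a control flow op ("If"/"Loop"/"Scan") stays in the
    // candidate list only if every one of its subgraphs is fully TRT eligible.
    // `subgraph_eligibility` is a hypothetical stand-in for the per-subgraph result
    // the EP computes by running the TensorRT parser over each subgraph.
    bool SupportedControlFlowOp(const std::vector<bool>& subgraph_eligibility) {
      return std::all_of(subgraph_eligibility.begin(), subgraph_eligibility.end(),
                         [](bool eligible) { return eligible; });
    }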
@@ -1467,24 +1470,31 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
     }
   }
 
+
+  // Use these local definitions for now
+  // TODO: Use provider options
+  int max_partition_iterations = 1000;
+  int min_subgraph_size = 1;
+
   bool early_termination = false;
-  supported_nodes_vector = ep->GetSupportedList(parser_nodes_vector, 0, p->max_partition_iterations_, graph, &early_termination);
+  supported_nodes_vector = ep->GetSupportedList(parser_nodes_vector, 0, max_partition_iterations, graph, &early_termination);
   if (early_termination) {
     supported_nodes_vector.clear();
   }
 
   // Remove subgraphs if its size is less than the predefined minimal size
   for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) {
     const size_t subgraph_size = it->first.size();
-    if (subgraph_size < p->min_subgraph_size_) {
+    if (subgraph_size < min_subgraph_size) {
       supported_nodes_vector.erase(it--);
     }
   }
 
   // Detect and remove cycles from supported node list
-  // p->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash);
+  /* ep->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); */
 
   // Consolidate supported node list
+  /*
   if (supported_nodes_vector.size() > 1) {
     nodes_vector.clear();
     for (const auto& group : supported_nodes_vector) {
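The TODO in this hunk notes that max_partition_iterations and min_subgraph_size should eventually come from provider options rather than local constants. A minimal sketch of that lookup, assuming the options are delivered as a string-to-string map and reusing the key names of the classic TensorRT EP (trt_max_partition_iterations, trt_min_subgraph_size); both the map type and the key names are assumptions, not part of this change:

    #include <exception>
    #include <string>
    #include <unordered_map>

    // Hypothetical helper: read an integer provider option, falling back to the
    // defaults used in this change (1000 / 1) when the key is absent or malformed.
    int GetIntOption(const std::unordered_map<std::string, std::string>& options,
                     const std::string& key, int default_value) {
      auto it = options.find(key);
      if (it == options.end()) return default_value;
      try {
        return std::stoi(it->second);
      } catch (const std::exception&) {
        return default_value;  // keep the hard-coded default on parse failure
      }
    }

    // Usage sketch (option names are assumed, mirroring the classic TRT EP):
    // int max_partition_iterations = GetIntOption(opts, "trt_max_partition_iterations", 1000);
    // int min_subgraph_size        = GetIntOption(opts, "trt_min_subgraph_size", 1);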
@@ -1500,11 +1510,12 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
       supported_nodes_vector = consolidated_supported_nodes_vector;
     }
   }
+  */
 
-  std::vector<OrtIndexedSubGraph*> cache;
   // Handle the case where the graph is subgraph of control flow op.
   // The purpose is to make control flow op as well as its subgraphs run on TRT.
   // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level.
+  /*
   if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) {
     bool all_subgraphs_are_supported = true;
 
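While the control-flow branch above is commented out, it may help to keep in mind what IsSubGraphFullySupported amounts to: the TRT-supported groups must together cover every node of the (sub)graph. A minimal sketch, assuming SubGraphCollection_t is a vector of (node-index list, supported flag) pairs, which is what the loops in this function suggest:

    #include <cstddef>
    #include <utility>
    #include <vector>

    // Assumed shape of SubGraphCollection_t, inferred from the surrounding loops
    // (group.first is a list of node indices).
    using SubGraph_t = std::pair<std::vector<size_t>, bool>;
    using SubGraphCollection_t = std::vector<SubGraph_t>;

    // A subgraph of a control flow op is "fully supported" when the supported
    // groups together account for every node in it.
    bool IsSubGraphFullySupported(const SubGraphCollection_t& supported_nodes_vector,
                                  size_t number_of_ort_nodes) {
      size_t covered = 0;
      for (const auto& group : supported_nodes_vector) {
        covered += group.first.size();
      }
      return covered == number_of_ort_nodes;
    }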
@@ -1580,32 +1591,33 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
       return;
     }
   }
+  */
 
-  int number_of_trt_nodes = 0, subgraph_index = 0;
+  int number_of_trt_nodes = 0;
   for (const auto& group : supported_nodes_vector) {
     if (!group.first.empty()) {
-      std::unique_ptr<OrtIndexedSubGraph> sub_graph = p->GetSubGraph(group, graph, model_hash, subgraph_index);
-      cache.push_back(sub_graph.release());
+      std::vector<const OrtNode*> supported_nodes;
+      for (const auto& index : group.first) {
+        const OrtNode* supported_node = nullptr;
+        RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetElementAt(nodes_container, index,
+                                                                     reinterpret_cast<const void**>(&supported_node)));
+        supported_nodes.push_back(supported_node);
+      }
+      RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(),
+                                                                   supported_nodes.size()));
       number_of_trt_nodes += static_cast<int>(group.first.size());
-      subgraph_index++;
     }
   }
 
   const size_t number_of_subgraphs = supported_nodes_vector.size();
   if (number_of_trt_nodes == 0) {
     // LOGS_DEFAULT(WARNING) << "[TensorRT EP] No graph will run on TensorRT execution provider";
-  } else if (number_of_trt_nodes == number_of_ort_nodes) {
+  } else if (number_of_trt_nodes == nodes.size()) {
     // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
   } else {
     // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
   }
 
-  *cnt = cache.size();
-  *indexed_sub_graph = new OrtIndexedSubGraph*[*cnt];
-  for (size_t i = 0; i < *cnt; i++) {
-    (*indexed_sub_graph)[i] = cache[i];
-  }
-
   return nullptr;
 }
 
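One observation on the size filter earlier in this function (the loop that calls supported_nodes_vector.erase(it--)): stepping the iterator backwards is undefined behaviour when the very first group is erased, because it decrements begin(). If that loop is revisited, the same filtering can be expressed with the erase/remove_if idiom. A self-contained sketch under the same assumed SubGraphCollection_t shape as above:

    #include <algorithm>
    #include <cstddef>
    #include <utility>
    #include <vector>

    using SubGraph_t = std::pair<std::vector<size_t>, bool>;  // assumed shape
    using SubGraphCollection_t = std::vector<SubGraph_t>;

    // Equivalent filtering without iterator back-stepping: drop every group whose
    // node count is below the configured minimum subgraph size.
    void RemoveSmallSubgraphs(SubGraphCollection_t& supported_nodes_vector,
                              size_t min_subgraph_size) {
      supported_nodes_vector.erase(
          std::remove_if(supported_nodes_vector.begin(), supported_nodes_vector.end(),
                         [min_subgraph_size](const SubGraph_t& group) {
                           return group.first.size() < min_subgraph_size;
                         }),
          supported_nodes_vector.end());
    }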