33#include < list>
44#include < functional>
55#include < iostream>
6+ #include < numeric>
67#include < cuda_runtime.h>
78
89#include " onnxruntime_cxx_api.h"
@@ -737,11 +738,61 @@ OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx,
737738 return nullptr ;
738739}
739740
741+ bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP (const OrtGraph* graph, const std::string& provider_type) const {
742+ size_t num_nodes = 0 ;
743+ THROW_IF_ERROR (ort_api.Graph_GetNumNodes (graph, &num_nodes));
744+
745+ // Get all the nodes from the graph
746+ std::vector<const OrtNode*> nodes (num_nodes);
747+ THROW_IF_ERROR (ort_api.Graph_GetNodes (graph, nodes.data (), nodes.size ()));
748+
749+ for (const auto node : nodes) {
750+ const char * ep_name;
751+ THROW_IF_ERROR (ort_api.Node_GetEpName (node, &ep_name));
752+
753+ if (std::string (ep_name) != provider_type) {
754+ return false ;
755+ }
756+ }
757+
758+ return num_nodes != 0 ;
759+ }
760+
761+ // Check the graph is the subgraph of control flow op
762+ bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp (const OrtGraph* graph) const {
763+ const OrtNode* parent_node = nullptr ;
764+ THROW_IF_ERROR (ort_api.Graph_GetParentNode (graph, &parent_node));
765+ if (parent_node) {
766+ const char * op_type = nullptr ;
767+ THROW_IF_ERROR (ort_api.Node_GetOperatorType (parent_node, &op_type));
768+
769+ if (control_flow_op_set_.find (std::string (op_type)) != control_flow_op_set_.end ()) {
770+ return true ;
771+ }
772+ }
773+ return false ;
774+ }
775+
776+ // Check whether all the nodes of subgraph are supported
777+ bool TensorrtExecutionProvider::IsSubGraphFullySupported (const OrtGraph* graph, SubGraphCollection_t supported_nodes_vector) const {
778+ size_t num_nodes = 0 ;
779+ THROW_IF_ERROR (ort_api.Graph_GetNumNodes (graph, &num_nodes));
780+
781+ int number_of_trt_nodes = 0 ;
782+ for (const auto & group : supported_nodes_vector) {
783+ if (!group.first .empty ()) {
784+ number_of_trt_nodes += static_cast <int >(group.first .size ());
785+ }
786+ }
787+
788+ return number_of_trt_nodes == num_nodes;
789+ }
790+
740791SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList (SubGraphCollection_t nodes_vector_input,
741792 int iterations, const int max_iterations,
742793 const OrtGraph* graph, bool * early_termination) const {
743- // Return if iterations are exceeding predefined number
744- SubGraphCollection_t nodes_list_output;
794+ // Temporarily make all nodes supported: the input collection is returned unchanged, and iterations, max_iterations, graph and early_termination are not consulted yet
795+ SubGraphCollection_t nodes_list_output = nodes_vector_input ;
745796
746797 return nodes_list_output;
747798}
@@ -750,6 +801,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
750801 OrtEpGraphSupportInfo* graph_support_info) noexcept {
751802 TensorrtExecutionProvider* ep = static_cast <TensorrtExecutionProvider*>(this_ptr);
752803 const OrtApi& ort_api = ep->ort_api ;
804+ auto ort_graph = Ort::ConstGraph (graph);
753805
754806 size_t num_nodes = 0 ;
755807 RETURN_IF_ERROR (ort_api.Graph_GetNumNodes (graph, &num_nodes));
@@ -776,8 +828,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
776828 return set;
777829 };
778830
779- // auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_);
780- auto exclude_ops_set = get_exclude_ops_set (" " );
831+ auto exclude_ops_set = get_exclude_ops_set (ep->op_types_to_exclude_ );
781832
782833 /* Iterate all the nodes and exclude the node if:
783834 * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
@@ -821,12 +872,10 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
821872 continue ;
822873 }
823874
824- /*
825- if (!ep->AllNodesAssignedToSpecificEP(*(subgraph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
875+ if (!ep->AllNodesAssignedToSpecificEP (subgraph, ep->name_ )) {
826876 // if not all its subgraphs are supported, we need to exclude this control flow op
827877 return false ;
828878 }
829- */
830879 }
831880 return true ;
832881 };
@@ -862,9 +911,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
862911 supported_nodes_vector.clear ();
863912 }
864913
865- // Temporarily make all nodes supported
866- supported_nodes_vector = parser_nodes_vector;
867-
868914 // Remove subgraphs if its size is less than the predefined minimal size
869915 for (auto it = supported_nodes_vector.begin (); it != supported_nodes_vector.end (); ++it) {
870916 const size_t subgraph_size = it->first .size ();
@@ -873,108 +919,83 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
873919 }
874920 }
875921
876- // Detect and remove cycles from supported node list
877- /* ep->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); */
878-
879- // Consolidate supported node list
880- /*
881- if (supported_nodes_vector.size() > 1) {
882- nodes_vector.clear();
883- for (const auto& group : supported_nodes_vector) {
884- if (!group.first.empty()) {
885- nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end());
886- }
887- }
888- SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}};
889- if (p->DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, model_hash, false)) {
890- // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation";
891- } else {
892- // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph";
893- supported_nodes_vector = consolidated_supported_nodes_vector;
894- }
895- }
896- */
922+ // TODO: Detect and remove cycles from supported node list
897923
924+ // TODO: Consolidate supported node list
925+
898926 // Handle the case where the graph is subgraph of control flow op.
899927 // The purpose is to make control flow op as well as its subgraphs run on TRT.
900928 // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level.
901- /*
902- if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) {
929+ if (ep-> IsSubGraphOfControlFlowOp (graph) && ep-> IsSubGraphFullySupported (graph, supported_nodes_vector)) {
930+ // const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1);
903931 bool all_subgraphs_are_supported = true ;
904932
905933 // "If" control flow op has two subgraph bodies, "then" body and "else" body respectively.
906934 // Check its parent node's another subgraph to see whether that subgraph is also fully supported by TRT.
907- const OrtNode* parent_node = nullptr;
908- graph_api_->OrtGraph_GetParenNode(graph, &parent_node);
909- const char* parent_node_op_type = nullptr;
910- graph_api_->OrtNode_GetOpType(parent_node, &parent_node_op_type);
911- if (strcmp(parent_node_op_type, "If") == 0) {
935+ Ort::ConstNode parent_node = ort_graph.GetParentNode ();
936+ if (parent_node.GetOperatorType () == " If" ) {
912937 all_subgraphs_are_supported = false ;
913938 SubGraphCollection_t subgraph_supported_nodes_vector;
914- const OrtGraphViewer** subgraphs = nullptr;
915- size_t subgraph_count = 0;
916- graph_api_->OrtNode_GetSubgraphs(parent_node, &subgraphs, &subgraph_count);
917- for (size_t i = 0; i < subgraph_count; i++) {
918- bool same_graph = false;
919- graph_api_->OrtGraph_IsSameGraph(graph, subgraphs[i], &same_graph);
920- if (same_graph) {
921- continue;
922- }
923- int number_of_ort_subgraph_nodes = 0;
924- graph_api_->OrtGraph_NumberOfNodes(subgraphs[i], &number_of_ort_subgraph_nodes);
925- std::vector<size_t> subgraph_nodes_vector(number_of_ort_subgraph_nodes);
926- std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0);
927- SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}};
928- bool subgraph_early_termination = false;
929-
930- // Another subgraph of "If" control flow op has no nodes.
931- // In this case, TRT EP should consider this empty subgraph is fully supported by TRT.
932- if (number_of_ort_subgraph_nodes == 0) {
933- all_subgraphs_are_supported = true;
934- break;
935- }
936- // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP.
937- else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], tensorrtEp)) {
938- all_subgraphs_are_supported = true;
939- break;
940- }
941- // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP.
942- // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs)
943- else if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], "")) {
944- all_subgraphs_are_supported = false;
939+
940+ std::vector<Ort::AttrNameSubgraph> attr_name_subgraphs = parent_node.GetSubgraphs ();
941+ for (auto attr_name_subgraph : attr_name_subgraphs) {
942+ auto subgraph = attr_name_subgraph.sub_graph ;
943+ const OrtGraph* subgraph_raw_pointer = subgraph;
944+ if (subgraph_raw_pointer != graph) {
945+
946+ size_t num_subgraph_nodes = 0 ;
947+ THROW_IF_ERROR (ort_api.Graph_GetNumNodes (subgraph, &num_subgraph_nodes));
948+
949+ // Another subgraph of "If" control flow op has no nodes.
950+ // In this case, TRT EP should consider this empty subgraph is fully supported by TRT.
951+ if (num_subgraph_nodes == 0 ) {
952+ all_subgraphs_are_supported = true ;
953+ break ;
954+ }
955+ // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP.
956+ else if (ep->AllNodesAssignedToSpecificEP (subgraph, ep->name_ )) {
957+ all_subgraphs_are_supported = true ;
958+ break ;
959+ }
960+ // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP.
961+ // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs)
962+ else if (!ep->AllNodesAssignedToSpecificEP (subgraph, " " )) {
963+ all_subgraphs_are_supported = false ;
964+ break ;
965+ }
966+
967+ std::vector<size_t > subgraph_nodes_vector (num_subgraph_nodes);
968+ std::iota (std::begin (subgraph_nodes_vector), std::end (subgraph_nodes_vector), 0 );
969+ SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false }};
970+ bool subgraph_early_termination = false ;
971+
972+ // Another subgraph of "If" control flow has not yet been parsed by GetCapability.
973+ subgraph_supported_nodes_vector = ep->GetSupportedList (parser_subgraph_nodes_vector, 0 , ep->max_partition_iterations_ , subgraph, &subgraph_early_termination);
974+ all_subgraphs_are_supported = ep->IsSubGraphFullySupported (subgraph, subgraph_supported_nodes_vector);
945975 break ;
946976 }
947-
948- // Another subgraph of "If" control flow has not yet been parsed by GetCapability.
949- subgraph_supported_nodes_vector = p->GetSupportedList(parser_subgraph_nodes_vector, 0, p->max_partition_iterations_, subgraphs[i], &subgraph_early_termination);
950- all_subgraphs_are_supported = p->IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes);
951- break;
952977 }
953- graph_api_->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count);
954978 }
955979
956980 if (all_subgraphs_are_supported) {
981+ // We want the subgraph nodes to be assigned to TRT EP but don't want them to be fused until later at the control flow op level.
982+ // Simply request the subgraph nodes with a single ComputeCapability for each with no MetaDef (i.e. what the default implementation for IExecutionProvider::GetCapability does).
957983 for (const auto & group : supported_nodes_vector) {
958984 if (!group.first .empty ()) {
959985 for (const auto & index : group.first ) {
960- std::unique_ptr<OrtIndexedSubGraph> sub_graph = std::make_unique<OrtIndexedSubGraph>();
961- sub_graph->node_index_len = 1;
962- sub_graph->node_index = new size_t[sub_graph->node_index_len];
963- sub_graph->node_index[0] = nodes_index[index];
964- cache.push_back(sub_graph.release());
986+ const OrtNode* supported_node = nodes[index];
987+ RETURN_IF_ERROR (ep->ep_api .EpGraphSupportInfo_AddSingleNode (graph_support_info, supported_node));
965988 }
966989 }
967990 }
968- *cnt = cache.size();
969- *indexed_sub_graph = new OrtIndexedSubGraph*[*cnt];
970- for (size_t i = 0; i < *cnt; i++) {
971- (*indexed_sub_graph)[i] = cache[i];
972- }
973- // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
974- return;
991+ std::string message = " [TensorRT EP] Whole graph will run on TensorRT execution provider" ;
992+ Ort::ThrowOnError (ep->ort_api .Logger_LogMessage (&ep->logger_ ,
993+ OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
994+ message.c_str (), ORT_FILE, __LINE__, __FUNCTION__));
995+
996+ return nullptr ;
975997 }
976998 }
977- */
978999
9791000 int number_of_trt_nodes = 0 ;
9801001 for (const auto & group : supported_nodes_vector) {
@@ -2251,7 +2272,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
22512272
22522273 // The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to
22532274 // the session option configurations with the key prefix "ep.<lowercase_ep_name>.".
2254- // We extract those EP options to create a new "provider options" key/ value map.
2275+ // We extract those EP options to create a new "provider options" key- value map.
22552276 std::string lowercase_ep_name = name_.c_str ();
22562277 std::transform (lowercase_ep_name.begin (), lowercase_ep_name.end (), lowercase_ep_name.begin (),
22572278 [](unsigned char c) { return static_cast <char >(std::tolower (c)); });
@@ -2289,7 +2310,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
22892310 info_ = TensorrtExecutionProviderInfo::FromProviderOptions (provider_options);
22902311 info_.has_trt_options = true ;
22912312 device_id_ = info_.device_id ;
2292- // api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device);
22932313
22942314 std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
22952315
@@ -2358,6 +2378,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
23582378 profile_opt_shapes = info_.profile_opt_shapes ;
23592379 cuda_graph_enable_ = info_.cuda_graph_enable ;
23602380 engine_hw_compatible_ = info_.engine_hw_compatible ;
2381+ op_types_to_exclude_ = info_.op_types_to_exclude ;
23612382 } else {
23622383 // deprecate env provider option
23632384 }
0 commit comments