99#include " onnxruntime_cxx_api.h"
1010#undef ORT_API_MANUAL_INIT
1111
12+ #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL
13+ #include " ort_graph_to_proto.h"
14+
1215#include " ep_abi_utils.h"
1316// #include "tensorrt_execution_provider_utils.h"
1417#include " tensorrt_execution_provider.h"
@@ -716,6 +719,267 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
716719 const OrtGraph* graph, bool * early_termination) const {
717720 // Return if iterations are exceeding predefined number
718721 SubGraphCollection_t nodes_list_output;
722+ if (iterations > max_iterations) {
723+ *early_termination = true ;
724+ return nodes_list_output;
725+ }
726+
727+ // Get parent graph output names
728+ std::unordered_set<std::string> graph_output_names;
729+ for (const auto * output_arg : graph.GetOutputs ()) {
730+ graph_output_names.insert (output_arg->Name ());
731+ }
732+
733+ iterations++;
734+ const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder (1 /* priority-based topological sort*/ );
735+ for (const auto & group : nodes_vector_input) {
736+ // Construct subgraph
737+ if (!group.first .empty ()) {
738+ if (group.second ) {
739+ nodes_list_output.push_back (group);
740+ } else {
741+ auto model_build = graph.CreateModel (*GetLogger ());
742+ auto & graph_build = model_build->MainGraph ();
743+ bool has_control_flow_op = false ;
744+
745+ // Add node and node args
746+ // If node output is also parent graph output, the output will be added to the
747+ // subgraph's output list
748+ std::vector<std::string> subgraph_output_names;
749+ for (const auto & index : group.first ) {
750+ // Initializers that refer to a memory location in OrtValue
751+ // can not be handled by TRT (unlike those that are on disk).
752+ // This prevents us from sharing the data and we have to make a copy here.
753+ constexpr const bool load_initializers_inline_true = true ;
754+ const auto & node = graph.GetNode (node_index[index]);
755+ std::vector<onnxruntime::NodeArg*> inputs, outputs;
756+ for (auto input : node->InputDefs ()) {
757+ auto & n_input = graph_build.GetOrCreateNodeArg (input->Name (), input->TypeAsProto ());
758+ inputs.push_back (&n_input);
759+ graph_utils::MakeInitializerCopyIfNotExist (graph.GetGraph (), graph_build, input->Name (),
760+ load_initializers_inline_true);
761+ }
762+
763+ for (auto input : node->ImplicitInputDefs ()) {
764+ graph_utils::MakeInitializerCopyIfNotExist (graph.GetGraph (), graph_build, input->Name (),
765+ load_initializers_inline_true);
766+ }
767+ for (auto output : node->OutputDefs ()) {
768+ auto & n_output = graph_build.GetOrCreateNodeArg (output->Name (), output->TypeAsProto ());
769+ outputs.push_back (&n_output);
770+ const auto name = output->Name ();
771+ if (graph_output_names.find (name) != graph_output_names.end ()) {
772+ subgraph_output_names.push_back (name);
773+ }
774+ }
775+
776+ if (control_flow_op_set_.find (node->OpType ()) != control_flow_op_set_.end ()) {
777+ has_control_flow_op = true ;
778+ }
779+
780+ // If the node has subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node
781+ // attributes are not in sync because of graph optimization. Therefore, we need to force GraphProto attributes
782+ // to be updated in order to get the valid GraphProto.
783+ if (node->GetAttributes ().size () > 0 ) {
784+ auto node_proto = ONNX_NAMESPACE::NodeProto::Create ();
785+ // we need to update any GraphProto attributes for subgraphs so that any changes made by things
786+ // such as the optimizers are captured. otherwise we can end up saving an invalid graph.
787+ node->ToProto (*node_proto, /* update_subgraphs */ true );
788+ const int num_attributes = node_proto->attribute_size ();
789+ auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create ();
790+ node_attributes->reserve (num_attributes);
791+
792+ for (int i = 0 ; i < num_attributes; ++i) {
793+ auto & attr = node_proto->attribute (i);
794+ node_attributes->emplace (attr.name (), attr);
795+ }
796+
797+ // The GraphProto attributes are the updated ones.
798+ graph_build.AddNode (node->Name (), node->OpType (), node->Description (), inputs, outputs,
799+ node_attributes.get (), node->Domain ());
800+ } else {
801+ // The GraphProto attributes are the original ones.
802+ graph_build.AddNode (node->Name (), node->OpType (), node->Description (), inputs, outputs,
803+ &node->GetAttributes (), node->Domain ());
804+ }
805+ }
806+
807+ // Only if the newly built graph has control flow op as well as it has parent node,
808+ // it needs to handle outer scope values before calling graph.Resolve().
809+ if (has_control_flow_op && graph.ParentNode ()) {
810+ LOGS_DEFAULT (VERBOSE) << " [TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name ();
811+ BuildSubGraphContext (graph_build);
812+ SetGraphOuterScopeValuesAndInputs (graph_build, graph.GetGraph ());
813+ SetAllGraphInputs (graph_build);
814+ }
815+
816+ ORT_ENFORCE (graph_build.Resolve ().IsOK ());
817+
818+ // Add parent graph output to the subgraph
819+ int i = 0 ;
820+ std::vector<const NodeArg*> subgraph_outputs;
821+ subgraph_outputs.resize (subgraph_output_names.size ());
822+ for (auto & name : subgraph_output_names) {
823+ auto output_arg = graph.GetNodeArg (name);
824+ auto & subgraph_output_arg = graph_build.GetOrCreateNodeArg (output_arg->Name (), output_arg->TypeAsProto ());
825+ subgraph_outputs[i] = &subgraph_output_arg;
826+ ++i;
827+ }
828+ auto & graph_build_outputs = graph_build.GetOutputs ();
829+ subgraph_outputs.insert (subgraph_outputs.begin (), graph_build_outputs.begin (), graph_build_outputs.end ());
// NOTE(review): subgraph_outputs — which on the previous line had the graph's own outputs merged
// in front of the parent-graph outputs collected above — is never used after this point.
// SetOutputs() is called with graph_build_outputs, i.e. the outputs the graph already had, making
// the whole subgraph_outputs construction dead code. If the intent of this section (see the
// "Add parent graph output to the subgraph" comment) is to expose parent-graph outputs from the
// subgraph, this should presumably be graph_build.SetOutputs(subgraph_outputs) — TODO confirm
// against the upstream TensorRT EP implementation before changing behavior.
830+ graph_build.SetOutputs (graph_build_outputs);
831+ ORT_ENFORCE (graph_build.Resolve ().IsOK ());
832+
833+ // Check if input tensors have shapes
834+ if (iterations > 1 ) {
835+ auto graph_inputs = graph_build.GetInputs ();
836+ for (auto input_arg : graph_inputs) {
837+ bool has_dim_value_or_param = true ;
838+ auto input_shape = input_arg->Shape ();
839+ if (input_shape != nullptr ) {
840+ auto dim_size = input_shape->dim_size ();
841+ for (int i = 0 ; i < dim_size; ++i) {
842+ auto & dim = input_shape->dim (i);
843+ if (!dim.has_dim_value () && !dim.has_dim_param ()) {
844+ has_dim_value_or_param = false ;
845+ break ;
846+ }
847+ }
848+ }
849+
850+ if (input_shape == nullptr || !has_dim_value_or_param) {
851+ ORT_THROW_IF_ERROR (
852+ ORT_MAKE_STATUS (ONNXRUNTIME, FAIL,
853+ " TensorRT input: " + input_arg->Name () + " has no shape specified. " +
854+ " Please run shape inference on the onnx model first. Details can be found in " +
855+ " https://onnxruntime.ai/docs/execution-providers/"
856+ " TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs" ));
857+ }
858+ }
859+ }
860+
861+
862+ /*
863+ //Save initializers to external file
864+ std::string ext_ini_file_path = "model_serialized.bin";
865+ std::filesystem::remove(ext_ini_file_path);
866+ std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary);
867+ auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](
868+ const OrtValueInfo* value_info, const void* data, size_t bytes, bool& is_external,
869+ std::string& location, int64_t& offset) -> Ort::Status {
870+ // OrtValueInfo* could be used to query initializer's name, type, shape,
871+ // node consumers, etc.
872+ (void)value_info;
873+
874+ if (bytes <= 127) {
875+ is_external = false; // Keep small initializers stored inside the TensorProto.
876+ return Ort::Status{nullptr};
877+ }
878+
879+ offset = ext_ini_ofs.tellp();
880+ location = ext_ini_file_path;
881+ ext_ini_ofs.write(static_cast<const char*>(data), bytes);
882+ ext_ini_ofs.flush();
883+ is_external = true; // True if is external initializer.
884+
885+ return Ort::Status{nullptr};
886+ };
887+ */
888+
889+ // Construct ModelProto from OrtGraph
890+ ONNX_NAMESPACE::ModelProto model_proto;
891+
892+ // add back handle_initializer_data to save initializer to external file
// NOTE(review): OrtEpUtils::OrtGraphToProto is declared in ort_graph_to_proto.h to return a
// status; that return value is discarded here, so a failed graph->proto conversion would go
// unnoticed and an empty/partial model would be handed to the TRT parser — wrap in the file's
// error-check macro (TODO confirm the exact return type).
// NOTE(review): `*graph` dereferences the OrtGraph* parameter, but earlier code in this same
// function calls members on `graph` directly (e.g. graph.GetOutputs()); both usages cannot
// compile against the same declaration — verify which API shape this port intends.
893+ OrtEpUtils::OrtGraphToProto (*graph, model_proto /* , handle_initializer_data */ );
894+
895+ std::string string_buf;
// NOTE(review): SerializeToString's bool result is ignored — serialization fails (returns false)
// for models exceeding the 2GB protobuf limit, which is plausible when initializers are kept
// inline; consider checking it before feeding string_buf to the parser.
896+ model_proto.SerializeToString (&string_buf);
897+
898+ if (dump_subgraphs_) {
899+ // Dump TensorRT subgraph for debugging
900+ std::fstream dump (" TensorrtExecutionProvider_TRT_Subgraph.onnx" ,
901+ std::ios::out | std::ios::trunc | std::ios::binary);
902+ model_proto.SerializeToOstream (&dump);
903+ }
904+
905+ // Get supported node list recursively
906+ SubGraphCollection_t parser_nodes_list;
907+ TensorrtLogger& trt_logger = GetTensorrtLogger (detailed_build_log_);
908+ auto trt_builder = GetBuilder (trt_logger);
909+ auto network_flags = 0 ;
910+ #if NV_TENSORRT_MAJOR > 8
911+ network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_)
912+ ? 0
913+ : 1U << static_cast <uint32_t >(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED );
914+ #else
915+ network_flags |= 1U << static_cast <uint32_t >(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH );
916+ #endif
917+
918+ auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2 (network_flags));
919+ auto trt_parser =
920+ tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser (*trt_network, trt_logger));
921+
922+ #if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10
923+ auto is_model_supported = trt_parser->supportsModelV2 (string_buf.data (), string_buf.size (), model_path_);
924+
925+ // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined
926+ // behavior.
927+ auto num_subgraphs = trt_parser->getNbSubgraphs ();
928+ parser_nodes_list.reserve (num_subgraphs);
929+
930+ for (int64_t i = 0 ; i < num_subgraphs; ++i) {
931+ int64_t subgraph_len = 0 ;
932+ int64_t * nodes = trt_parser->getSubgraphNodes (i, subgraph_len);
933+ parser_nodes_list.emplace_back ();
934+ parser_nodes_list.back ().first .reserve (subgraph_len);
935+ for (int64_t j = 0 ; j < subgraph_len; ++j) {
936+ parser_nodes_list.back ().first .push_back (nodes[j]);
937+ }
938+ parser_nodes_list.back ().second = is_model_supported ? true : false ;
939+ }
940+ #else
941+ trt_parser->supportsModel (string_buf.data (), string_buf.size (), parser_nodes_list, model_path_);
942+ #endif
943+
944+ SubGraphCollection_t next_nodes_list;
// NOTE(review): `graph_viewer` is not declared anywhere in the visible body of this function —
// presumably it must be created from the freshly built subgraph before this point (the upstream
// GraphViewer-based implementation does `auto graph_viewer = graph_build.CreateGraphViewer();`).
// Also note the recursive call below passes `*graph_viewer` (a dereferenced value) where this
// function's own parameter is declared `const OrtGraph* graph` — the types cannot both be right;
// TODO confirm the intended parameter type for this EP-ABI port.
945+ const std::vector<NodeIndex>& subgraph_node_index =
946+ graph_viewer->GetNodesInTopologicalOrder (1 /* priority-based topological sort*/ );
947+ next_nodes_list =
948+ GetSupportedList (parser_nodes_list, iterations, max_iterations, *graph_viewer, early_termination);
949+ for (size_t i = 0 , end = next_nodes_list.size (); i < end; ++i) {
950+ for (size_t j = 0 , end = next_nodes_list[i].first .size (); j < end; ++j) {
951+ /*
952+ * Convert the supported node list returning from onnx-tensorrt parser to the node list recognized by ORT
953+ * TRT.
954+ *
955+ * TRT EP reconstructs the graph based on the nodes in group.first and feeds this graph (converts to model
956+ * proto and to string buffer) to onnx-tensorrt parser. The node index in the list returning from
957+ * onnx-tensorrt parser might not be the same as the node index in group.first. Therefore, TRT EP needs a
958+ * node index mapping table here.
959+ *
960+ * The order of iterating the nodes in group.first and calling graph_build.AddNode() determines the node
961+ * order in the newly constructed graph (see Graph::AllocateNode() in graph.cc), however, once the graph is
962+ * converted to model proto, the node proto order in model proto (ex: onnx-tensorrt calls
963+ * model.graph().node() to iterate NodeProto in ModelProto) is decided by topo sort.
964+ *
965+ * The topo sort list (i.e. subgraph_node_index) acts as the node index mapping table:
966+ * subgraph_node_index[node index from onnx-tensorrt parser] = index in group.first
967+ *
968+ * In the past, TRT EP uses ORT's default reversed DFS topo sort which might end up with the sorting result
969+ * not sequence of 0, 1, ... n-1, ex: the subgraph_node_index = [0,2,1,3,4]. With the change of using ORT's
970+ * priority-based topo sort (node with lower node index outputs first) the sorting result is the sequence of
971+ * 0, 1, ... n-1 for most of the cases, therefore subgraph_node_index as a mapping table is not needed
972+ * anymore.
973+ *
974+ * TODO: Remove the subgraph_node_index
975+ */
976+ next_nodes_list[i].first [j] = group.first [subgraph_node_index[next_nodes_list[i].first [j]]];
977+ }
978+ nodes_list_output.push_back (next_nodes_list[i]);
979+ }
980+ }
981+ }
982+ }
719983 return nodes_list_output;
720984}
721985
@@ -728,26 +992,50 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
728992 TensorrtExecutionProvider* ep = static_cast <TensorrtExecutionProvider*>(this_ptr);
729993
730994 /*
731- // Reconstruct graph proto from fused node's function body
732- auto model = graph_body_viewer.CreateModel(*GetLogger());
733- auto model_proto = model->ToProto();
734-
735- // ORT's default topological sort is using reversed DFS.
736- // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index.
737- // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating
738- // the model proto that has different node ordering compared to original onnx model.
739- graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1);
740- model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
995+ //Save initializers to external file
996+ std::string ext_ini_file_path = "model_serialized.bin";
997+ std::filesystem::remove(ext_ini_file_path);
998+ std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary);
999+ auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](
1000+ const OrtValueInfo* value_info, const void* data, size_t bytes, bool& is_external,
1001+ std::string& location, int64_t& offset) -> Ort::Status {
1002+ // OrtValueInfo* could be used to query initializer's name, type, shape,
1003+ // node consumers, etc.
1004+ (void)value_info;
1005+
1006+ if (bytes <= 127) {
1007+ is_external = false; // Keep small initializers stored inside the TensorProto.
1008+ return Ort::Status{nullptr};
1009+ }
1010+
1011+ offset = ext_ini_ofs.tellp();
1012+ location = ext_ini_file_path;
1013+ ext_ini_ofs.write(static_cast<const char*>(data), bytes);
1014+ ext_ini_ofs.flush();
1015+ is_external = true; // True if is external initializer.
1016+
1017+ return Ort::Status{nullptr};
1018+ };
1019+ */
1020+
1021+ // Construct ModelProto from OrtGraph
1022+ ONNX_NAMESPACE::ModelProto model_proto;
1023+
1024+ // add back handle_initializer_data to save initializer to external file
1025+ OrtEpUtils::OrtGraphToProto (*graph, model_proto /* , handle_initializer_data */ );
1026+
7411027 std::string string_buf;
742- model_proto-> SerializeToString(string_buf);
1028+ model_proto. SerializeToString (& string_buf);
7431029
7441030 if (dump_subgraphs_) {
7451031 // Dump TensorRT subgraphs
746- std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
747- model_proto->SerializeToOstream(dump);
1032+ const char * name = nullptr ;
1033+ RETURN_IF_ERROR (ort_api.Node_GetName (fused_node, &name));
1034+ std::string subgraph_name = name;
1035+ subgraph_name += " .onnx" ;
1036+ std::fstream dump (subgraph_name, std::ios::out | std::ios::trunc | std::ios::binary);
1037+ model_proto.SerializeToOstream (&dump);
7481038 }
749- */
750- std::string string_buf;
7511039
7521040 TensorrtLogger& trt_logger = GetTensorrtLogger (detailed_build_log_);
7531041 auto trt_builder = GetBuilder (trt_logger);
@@ -1356,6 +1644,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
13561644 }
13571645
13581646 // Create input to index map
1647+ // TRT network input -> ORT fused_node input index
13591648 for (int i = 0 ; i < num_inputs; ++i) {
13601649 auto input = trt_network->getInput (i);
13611650 const std::string& input_name = input->getName ();
@@ -1366,6 +1655,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
13661655 }
13671656
13681657 // Create output to index and type maps
1658+ // TRT network output -> ORT fused_node output index
13691659 const auto & graph_output = model_proto->graph ().output ();
13701660 for (int i = 0 ; i < num_outputs; ++i) {
13711661 const std::string& output_name = trt_network->getOutput (i)->getName ();
@@ -1789,7 +2079,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_
17892079 TensorrtExecutionProvider* ep = static_cast <TensorrtExecutionProvider*>(this_ptr);
17902080 const OrtApi& ort_api = ep->ort_api ;
17912081
1792- gsl::span<OrtNodeComputeInfo*> result (node_compute_infos, count);
2082+ gsl::span<OrtNodeComputeInfo*> node_compute_infos_result (node_compute_infos, count);
2083+ gsl::span<OrtNode*> ep_context_nodes_result (ep_context_nodes, count);
17932084
17942085 for (size_t fused_node_idx = 0 ; fused_node_idx < count; fused_node_idx++) {
17952086 auto fused_node = fused_nodes[fused_node_idx];
@@ -1833,11 +2124,12 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_
18332124 OrtStatus* status;
18342125 if (GraphHasCtxNode (graphs[fused_node_idx], ort_api)) {
18352126 RETURN_IF_ERROR (ep->CreateNodeComputeInfoFromPrecompiledEngine (this_ptr, graphs[fused_node_idx], fused_node,
1836- input_map,
1837- output_map, &result [fused_node_idx]));
2127+ input_map, output_map ,
2128+ &node_compute_infos_result [fused_node_idx]));
18382129 } else {
18392130 RETURN_IF_ERROR (ep->CreateNodeComputeInfoFromGraph (this_ptr, graphs[fused_node_idx], fused_node, input_map,
1840- output_map, &result[fused_node_idx]));
2131+ output_map, &node_compute_infos_result[fused_node_idx]),
2132+ &ep_context_nodes_result[fused_node_idx]);
18412133 }
18422134 }
18432135
0 commit comments