Skip to content

Commit 3ced4cf

Browse files
committed
add ort_graph_to_proto.h and leverage OrtGraphToProto utilities
1 parent 3ad7736 commit 3ced4cf

File tree

2 files changed

+1029
-19
lines changed

2 files changed

+1029
-19
lines changed

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 311 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
#include "onnxruntime_cxx_api.h"
1010
#undef ORT_API_MANUAL_INIT
1111

12+
#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL
13+
#include "ort_graph_to_proto.h"
14+
1215
#include "ep_abi_utils.h"
1316
//#include "tensorrt_execution_provider_utils.h"
1417
#include "tensorrt_execution_provider.h"
@@ -716,6 +719,267 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
716719
const OrtGraph* graph, bool* early_termination) const {
717720
// Return if iterations are exceeding predefined number
718721
SubGraphCollection_t nodes_list_output;
722+
if (iterations > max_iterations) {
723+
*early_termination = true;
724+
return nodes_list_output;
725+
}
726+
727+
// Get parent graph output names
728+
std::unordered_set<std::string> graph_output_names;
729+
for (const auto* output_arg : graph.GetOutputs()) {
730+
graph_output_names.insert(output_arg->Name());
731+
}
732+
733+
iterations++;
734+
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/);
735+
for (const auto& group : nodes_vector_input) {
736+
// Construct subgraph
737+
if (!group.first.empty()) {
738+
if (group.second) {
739+
nodes_list_output.push_back(group);
740+
} else {
741+
auto model_build = graph.CreateModel(*GetLogger());
742+
auto& graph_build = model_build->MainGraph();
743+
bool has_control_flow_op = false;
744+
745+
// Add node and node args
746+
// If node output is also parent graph output, the output will be added to the
747+
// subgraph's output list
748+
std::vector<std::string> subgraph_output_names;
749+
for (const auto& index : group.first) {
750+
// Initializers that refer to a memory location in OrtValue
751+
// can not be handled by TRT (unlike those that are on disk).
752+
// This prevents us from sharing the data and we have to make a copy here.
753+
constexpr const bool load_initializers_inline_true = true;
754+
const auto& node = graph.GetNode(node_index[index]);
755+
std::vector<onnxruntime::NodeArg*> inputs, outputs;
756+
for (auto input : node->InputDefs()) {
757+
auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
758+
inputs.push_back(&n_input);
759+
graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(),
760+
load_initializers_inline_true);
761+
}
762+
763+
for (auto input : node->ImplicitInputDefs()) {
764+
graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(),
765+
load_initializers_inline_true);
766+
}
767+
for (auto output : node->OutputDefs()) {
768+
auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
769+
outputs.push_back(&n_output);
770+
const auto name = output->Name();
771+
if (graph_output_names.find(name) != graph_output_names.end()) {
772+
subgraph_output_names.push_back(name);
773+
}
774+
}
775+
776+
if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) {
777+
has_control_flow_op = true;
778+
}
779+
780+
// If the node has subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node
781+
// attributes are not in sync because of graph optimization. Therefore, we need to force GraphProto attributes
782+
// to be updated in order to get the valid GraphProto.
783+
if (node->GetAttributes().size() > 0) {
784+
auto node_proto = ONNX_NAMESPACE::NodeProto::Create();
785+
// we need to update any GraphProto attributes for subgraphs so that any changes made by things
786+
// such as the optimizers are captured. otherwise we can end up saving an invalid graph.
787+
node->ToProto(*node_proto, /* update_subgraphs */ true);
788+
const int num_attributes = node_proto->attribute_size();
789+
auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create();
790+
node_attributes->reserve(num_attributes);
791+
792+
for (int i = 0; i < num_attributes; ++i) {
793+
auto& attr = node_proto->attribute(i);
794+
node_attributes->emplace(attr.name(), attr);
795+
}
796+
797+
// The GraphProto attributes are the updated ones.
798+
graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs,
799+
node_attributes.get(), node->Domain());
800+
} else {
801+
// The GraphProto attributes are the original ones.
802+
graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs,
803+
&node->GetAttributes(), node->Domain());
804+
}
805+
}
806+
807+
// Only if the newly built graph has control flow op as well as it has parent node,
808+
// it needs to handle outer scope values before calling graph.Resolve().
809+
if (has_control_flow_op && graph.ParentNode()) {
810+
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name();
811+
BuildSubGraphContext(graph_build);
812+
SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph());
813+
SetAllGraphInputs(graph_build);
814+
}
815+
816+
ORT_ENFORCE(graph_build.Resolve().IsOK());
817+
818+
// Add parent graph output to the subgraph
819+
int i = 0;
820+
std::vector<const NodeArg*> subgraph_outputs;
821+
subgraph_outputs.resize(subgraph_output_names.size());
822+
for (auto& name : subgraph_output_names) {
823+
auto output_arg = graph.GetNodeArg(name);
824+
auto& subgraph_output_arg = graph_build.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
825+
subgraph_outputs[i] = &subgraph_output_arg;
826+
++i;
827+
}
828+
auto& graph_build_outputs = graph_build.GetOutputs();
829+
subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end());
830+
graph_build.SetOutputs(graph_build_outputs);
831+
ORT_ENFORCE(graph_build.Resolve().IsOK());
832+
833+
// Check if input tensors have shapes
834+
if (iterations > 1) {
835+
auto graph_inputs = graph_build.GetInputs();
836+
for (auto input_arg : graph_inputs) {
837+
bool has_dim_value_or_param = true;
838+
auto input_shape = input_arg->Shape();
839+
if (input_shape != nullptr) {
840+
auto dim_size = input_shape->dim_size();
841+
for (int i = 0; i < dim_size; ++i) {
842+
auto& dim = input_shape->dim(i);
843+
if (!dim.has_dim_value() && !dim.has_dim_param()) {
844+
has_dim_value_or_param = false;
845+
break;
846+
}
847+
}
848+
}
849+
850+
if (input_shape == nullptr || !has_dim_value_or_param) {
851+
ORT_THROW_IF_ERROR(
852+
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
853+
"TensorRT input: " + input_arg->Name() + " has no shape specified. " +
854+
"Please run shape inference on the onnx model first. Details can be found in " +
855+
"https://onnxruntime.ai/docs/execution-providers/"
856+
"TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs"));
857+
}
858+
}
859+
}
860+
861+
862+
/*
863+
//Save initializers to external file
864+
std::string ext_ini_file_path = "model_serialized.bin";
865+
std::filesystem::remove(ext_ini_file_path);
866+
std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary);
867+
auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](
868+
const OrtValueInfo* value_info, const void* data, size_t bytes, bool& is_external,
869+
std::string& location, int64_t& offset) -> Ort::Status {
870+
// OrtValueInfo* could be used to query initializer's name, type, shape,
871+
// node consumers, etc.
872+
(void)value_info;
873+
874+
if (bytes <= 127) {
875+
is_external = false; // Keep small initializers stored inside the TensorProto.
876+
return Ort::Status{nullptr};
877+
}
878+
879+
offset = ext_ini_ofs.tellp();
880+
location = ext_ini_file_path;
881+
ext_ini_ofs.write(static_cast<const char*>(data), bytes);
882+
ext_ini_ofs.flush();
883+
is_external = true; // True if is external initializer.
884+
885+
return Ort::Status{nullptr};
886+
};
887+
*/
888+
889+
// Construct ModelProto from OrtGraph
890+
ONNX_NAMESPACE::ModelProto model_proto;
891+
892+
// add back handle_initializer_data to save initializer to external file
893+
OrtEpUtils::OrtGraphToProto(*graph, model_proto /*, handle_initializer_data */);
894+
895+
std::string string_buf;
896+
model_proto.SerializeToString(&string_buf);
897+
898+
if (dump_subgraphs_) {
899+
// Dump TensorRT subgraph for debugging
900+
std::fstream dump("TensorrtExecutionProvider_TRT_Subgraph.onnx",
901+
std::ios::out | std::ios::trunc | std::ios::binary);
902+
model_proto.SerializeToOstream(&dump);
903+
}
904+
905+
// Get supported node list recursively
906+
SubGraphCollection_t parser_nodes_list;
907+
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
908+
auto trt_builder = GetBuilder(trt_logger);
909+
auto network_flags = 0;
910+
#if NV_TENSORRT_MAJOR > 8
911+
network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_)
912+
? 0
913+
: 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
914+
#else
915+
network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
916+
#endif
917+
918+
auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));
919+
auto trt_parser =
920+
tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
921+
922+
#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10
923+
auto is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_);
924+
925+
// Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined
926+
// behavior.
927+
auto num_subgraphs = trt_parser->getNbSubgraphs();
928+
parser_nodes_list.reserve(num_subgraphs);
929+
930+
for (int64_t i = 0; i < num_subgraphs; ++i) {
931+
int64_t subgraph_len = 0;
932+
int64_t* nodes = trt_parser->getSubgraphNodes(i, subgraph_len);
933+
parser_nodes_list.emplace_back();
934+
parser_nodes_list.back().first.reserve(subgraph_len);
935+
for (int64_t j = 0; j < subgraph_len; ++j) {
936+
parser_nodes_list.back().first.push_back(nodes[j]);
937+
}
938+
parser_nodes_list.back().second = is_model_supported ? true : false;
939+
}
940+
#else
941+
trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_);
942+
#endif
943+
944+
SubGraphCollection_t next_nodes_list;
945+
const std::vector<NodeIndex>& subgraph_node_index =
946+
graph_viewer->GetNodesInTopologicalOrder(1 /*priority-based topological sort*/);
947+
next_nodes_list =
948+
GetSupportedList(parser_nodes_list, iterations, max_iterations, *graph_viewer, early_termination);
949+
for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) {
950+
for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) {
951+
/*
952+
* Convert the supported node list returning from onnx-tensorrt parser to the node list recognized by ORT
953+
* TRT.
954+
*
955+
* TRT EP reconstructs the graph based on the nodes in group.first and feeds this graph (converts to model
956+
* proto and to string buffer) to onnx-tensorrt parser. The node index in the list returning from
957+
* onnx-tensorrt parser might not be the same as the node index in group.first. Therefore, TRT EP needs a
958+
* node index mapping table here.
959+
*
960+
* The order of iterating the nodes in group.first and calling graph_build.AddNode() determines the node
961+
* order in the newly constructed graph (see Graph::AllocateNode() in graph.cc), however, once the graph is
962+
* converted to model proto, the node proto order in model proto (ex: onnx-tensorrt calls
963+
* model.graph().node() to iterate NodeProto in ModelProto) is decided by topo sort.
964+
*
965+
* The topo sort list (i.e. subgraph_node_index) acts as the node index mapping table:
966+
* subgraph_node_index[node index from onnx-tensorrt parser] = index in group.first
967+
*
968+
* In the past, TRT EP uses ORT's default reversed DFS topo sort which might end up with the sorting result
969+
* not sequence of 0, 1, ... n-1, ex: the subgraph_node_index = [0,2,1,3,4]. With the change of using ORT's
970+
* priority-based topo sort (node with lower node index outputs first) the sorting result is the sequence of
971+
* 0, 1, ... n-1 for most of the cases, therefore subgraph_node_index as a mapping table is not needed
972+
* anymore.
973+
*
974+
* TODO: Remove the subgraph_node_index
975+
*/
976+
next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]];
977+
}
978+
nodes_list_output.push_back(next_nodes_list[i]);
979+
}
980+
}
981+
}
982+
}
719983
return nodes_list_output;
720984
}
721985

@@ -728,26 +992,50 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
728992
TensorrtExecutionProvider* ep = static_cast<TensorrtExecutionProvider*>(this_ptr);
729993

730994
/*
731-
// Reconstruct graph proto from fused node's function body
732-
auto model = graph_body_viewer.CreateModel(*GetLogger());
733-
auto model_proto = model->ToProto();
734-
735-
// ORT's default topological sort is using reversed DFS.
736-
// When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index.
737-
// The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating
738-
// the model proto that has different node ordering compared to original onnx model.
739-
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1);
740-
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
995+
//Save initializers to external file
996+
std::string ext_ini_file_path = "model_serialized.bin";
997+
std::filesystem::remove(ext_ini_file_path);
998+
std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary);
999+
auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](
1000+
const OrtValueInfo* value_info, const void* data, size_t bytes, bool& is_external,
1001+
std::string& location, int64_t& offset) -> Ort::Status {
1002+
// OrtValueInfo* could be used to query initializer's name, type, shape,
1003+
// node consumers, etc.
1004+
(void)value_info;
1005+
1006+
if (bytes <= 127) {
1007+
is_external = false; // Keep small initializers stored inside the TensorProto.
1008+
return Ort::Status{nullptr};
1009+
}
1010+
1011+
offset = ext_ini_ofs.tellp();
1012+
location = ext_ini_file_path;
1013+
ext_ini_ofs.write(static_cast<const char*>(data), bytes);
1014+
ext_ini_ofs.flush();
1015+
is_external = true; // True if is external initializer.
1016+
1017+
return Ort::Status{nullptr};
1018+
};
1019+
*/
1020+
1021+
// Construct ModelProto from OrtGraph
1022+
ONNX_NAMESPACE::ModelProto model_proto;
1023+
1024+
// add back handle_initializer_data to save initializer to external file
1025+
OrtEpUtils::OrtGraphToProto(*graph, model_proto /*, handle_initializer_data */);
1026+
7411027
std::string string_buf;
742-
model_proto->SerializeToString(string_buf);
1028+
model_proto.SerializeToString(&string_buf);
7431029

7441030
if (dump_subgraphs_) {
7451031
// Dump TensorRT subgraphs
746-
std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
747-
model_proto->SerializeToOstream(dump);
1032+
const char* name = nullptr;
1033+
RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &name));
1034+
std::string subgraph_name = name;
1035+
subgraph_name += ".onnx";
1036+
std::fstream dump(subgraph_name, std::ios::out | std::ios::trunc | std::ios::binary);
1037+
model_proto.SerializeToOstream(&dump);
7481038
}
749-
*/
750-
std::string string_buf;
7511039

7521040
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
7531041
auto trt_builder = GetBuilder(trt_logger);
@@ -1356,6 +1644,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
13561644
}
13571645

13581646
// Create input to index map
1647+
// TRT network input -> ORT fused_node input index
13591648
for (int i = 0; i < num_inputs; ++i) {
13601649
auto input = trt_network->getInput(i);
13611650
const std::string& input_name = input->getName();
@@ -1366,6 +1655,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
13661655
}
13671656

13681657
// Create output to index and type maps
1658+
// TRT network output -> ORT fused_node output index
13691659
const auto& graph_output = model_proto->graph().output();
13701660
for (int i = 0; i < num_outputs; ++i) {
13711661
const std::string& output_name = trt_network->getOutput(i)->getName();
@@ -1789,7 +2079,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_
17892079
TensorrtExecutionProvider* ep = static_cast<TensorrtExecutionProvider*>(this_ptr);
17902080
const OrtApi& ort_api = ep->ort_api;
17912081

1792-
gsl::span<OrtNodeComputeInfo*> result(node_compute_infos, count);
2082+
gsl::span<OrtNodeComputeInfo*> node_compute_infos_result(node_compute_infos, count);
2083+
gsl::span<OrtNode*> ep_context_nodes_result(ep_context_nodes, count);
17932084

17942085
for (size_t fused_node_idx = 0; fused_node_idx < count; fused_node_idx++) {
17952086
auto fused_node = fused_nodes[fused_node_idx];
@@ -1833,11 +2124,12 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_
18332124
OrtStatus* status;
18342125
if (GraphHasCtxNode(graphs[fused_node_idx], ort_api)) {
18352126
RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node,
1836-
input_map,
1837-
output_map, &result[fused_node_idx]));
2127+
input_map, output_map,
2128+
&node_compute_infos_result[fused_node_idx]));
18382129
} else {
18392130
RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[fused_node_idx], fused_node, input_map,
1840-
output_map, &result[fused_node_idx]));
2131+
output_map, &node_compute_infos_result[fused_node_idx]),
2132+
&ep_context_nodes_result[fused_node_idx]);
18412133
}
18422134
}
18432135

0 commit comments

Comments
 (0)