Skip to content

Commit 4d32867

Browse files
committed
update GetCapabilityImpl()
1 parent 632d224 commit 4d32867

File tree

2 files changed

+120
-132
lines changed

2 files changed

+120
-132
lines changed

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 111 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <list>
44
#include <functional>
55
#include <iostream>
6+
#include <numeric>
67
#include <cuda_runtime.h>
78

89
#include "onnxruntime_cxx_api.h"
@@ -737,11 +738,61 @@ OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx,
737738
return nullptr;
738739
}
739740

741+
bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const OrtGraph* graph, const std::string& provider_type) const {
742+
size_t num_nodes = 0;
743+
THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
744+
745+
// Get all the nodes from the graph
746+
std::vector<const OrtNode*> nodes(num_nodes);
747+
THROW_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size()));
748+
749+
for (const auto node : nodes) {
750+
const char* ep_name;
751+
THROW_IF_ERROR(ort_api.Node_GetEpName(node, &ep_name));
752+
753+
if (std::string(ep_name) != provider_type) {
754+
return false;
755+
}
756+
}
757+
758+
return num_nodes != 0;
759+
}
760+
761+
// Check the graph is the subgraph of control flow op
762+
bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraph* graph) const {
763+
const OrtNode* parent_node = nullptr;
764+
THROW_IF_ERROR(ort_api.Graph_GetParentNode(graph, &parent_node));
765+
if (parent_node) {
766+
const char* op_type = nullptr;
767+
THROW_IF_ERROR(ort_api.Node_GetOperatorType(parent_node, &op_type));
768+
769+
if (control_flow_op_set_.find(std::string(op_type)) != control_flow_op_set_.end()) {
770+
return true;
771+
}
772+
}
773+
return false;
774+
}
775+
776+
// Check whether all the nodes of subgraph are supported
777+
bool TensorrtExecutionProvider::IsSubGraphFullySupported(const OrtGraph* graph, SubGraphCollection_t supported_nodes_vector) const {
778+
size_t num_nodes = 0;
779+
THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
780+
781+
int number_of_trt_nodes = 0;
782+
for (const auto& group : supported_nodes_vector) {
783+
if (!group.first.empty()) {
784+
number_of_trt_nodes += static_cast<int>(group.first.size());
785+
}
786+
}
787+
788+
return number_of_trt_nodes == num_nodes;
789+
}
790+
740791
SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input,
741792
int iterations, const int max_iterations,
742793
const OrtGraph* graph, bool* early_termination) const {
743-
// Return if iterations are exceeding predefined number
744-
SubGraphCollection_t nodes_list_output;
794+
// Temporarily make all nodes supported
795+
SubGraphCollection_t nodes_list_output = nodes_vector_input;
745796

746797
return nodes_list_output;
747798
}
@@ -750,6 +801,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
750801
OrtEpGraphSupportInfo* graph_support_info) noexcept {
751802
TensorrtExecutionProvider* ep = static_cast<TensorrtExecutionProvider*>(this_ptr);
752803
const OrtApi& ort_api = ep->ort_api;
804+
auto ort_graph = Ort::ConstGraph(graph);
753805

754806
size_t num_nodes = 0;
755807
RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
@@ -776,8 +828,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
776828
return set;
777829
};
778830

779-
// auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_);
780-
auto exclude_ops_set = get_exclude_ops_set("");
831+
auto exclude_ops_set = get_exclude_ops_set(ep->op_types_to_exclude_);
781832

782833
/* Iterate all the nodes and exclude the node if:
783834
* 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
@@ -821,12 +872,10 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
821872
continue;
822873
}
823874

824-
/*
825-
if (!ep->AllNodesAssignedToSpecificEP(*(subgraph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
875+
if (!ep->AllNodesAssignedToSpecificEP(subgraph, ep->name_)) {
826876
// if not all its subgraphs are supported, we need to exclude this control flow op
827877
return false;
828878
}
829-
*/
830879
}
831880
return true;
832881
};
@@ -862,9 +911,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
862911
supported_nodes_vector.clear();
863912
}
864913

865-
// Temporarily make all nodes supported
866-
supported_nodes_vector = parser_nodes_vector;
867-
868914
// Remove subgraphs if its size is less than the predefined minimal size
869915
for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) {
870916
const size_t subgraph_size = it->first.size();
@@ -873,108 +919,83 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
873919
}
874920
}
875921

876-
// Detect and remove cycles from supported node list
877-
/* ep->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); */
878-
879-
// Consolidate supported node list
880-
/*
881-
if (supported_nodes_vector.size() > 1) {
882-
nodes_vector.clear();
883-
for (const auto& group : supported_nodes_vector) {
884-
if (!group.first.empty()) {
885-
nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end());
886-
}
887-
}
888-
SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}};
889-
if (p->DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, model_hash, false)) {
890-
// LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation";
891-
} else {
892-
// LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph";
893-
supported_nodes_vector = consolidated_supported_nodes_vector;
894-
}
895-
}
896-
*/
922+
// TODO: Detect and remove cycles from supported node list
897923

924+
// TODO: Consolidate supported node list
925+
898926
// Handle the case where the graph is subgraph of control flow op.
899927
// The purpose is to make control flow op as well as its subgraphs run on TRT.
900928
// Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level.
901-
/*
902-
if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) {
929+
if (ep->IsSubGraphOfControlFlowOp(graph) && ep->IsSubGraphFullySupported(graph, supported_nodes_vector)) {
930+
//const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1);
903931
bool all_subgraphs_are_supported = true;
904932

905933
// "If" control flow op has two subgraph bodies, "then" body and "else" body respectively.
906934
// Check its parent node's another subgraph to see whether that subgraph is also fully supported by TRT.
907-
const OrtNode* parent_node = nullptr;
908-
graph_api_->OrtGraph_GetParenNode(graph, &parent_node);
909-
const char* parent_node_op_type = nullptr;
910-
graph_api_->OrtNode_GetOpType(parent_node, &parent_node_op_type);
911-
if (strcmp(parent_node_op_type, "If") == 0) {
935+
Ort::ConstNode parent_node = ort_graph.GetParentNode();
936+
if (parent_node.GetOperatorType() == "If") {
912937
all_subgraphs_are_supported = false;
913938
SubGraphCollection_t subgraph_supported_nodes_vector;
914-
const OrtGraphViewer** subgraphs = nullptr;
915-
size_t subgraph_count = 0;
916-
graph_api_->OrtNode_GetSubgraphs(parent_node, &subgraphs, &subgraph_count);
917-
for (size_t i = 0; i < subgraph_count; i++) {
918-
bool same_graph = false;
919-
graph_api_->OrtGraph_IsSameGraph(graph, subgraphs[i], &same_graph);
920-
if (same_graph) {
921-
continue;
922-
}
923-
int number_of_ort_subgraph_nodes = 0;
924-
graph_api_->OrtGraph_NumberOfNodes(subgraphs[i], &number_of_ort_subgraph_nodes);
925-
std::vector<size_t> subgraph_nodes_vector(number_of_ort_subgraph_nodes);
926-
std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0);
927-
SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}};
928-
bool subgraph_early_termination = false;
929-
930-
// Another subgraph of "If" control flow op has no nodes.
931-
// In this case, TRT EP should consider this empty subgraph is fully supported by TRT.
932-
if (number_of_ort_subgraph_nodes == 0) {
933-
all_subgraphs_are_supported = true;
934-
break;
935-
}
936-
// Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP.
937-
else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], tensorrtEp)) {
938-
all_subgraphs_are_supported = true;
939-
break;
940-
}
941-
// Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP.
942-
// (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs)
943-
else if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], "")) {
944-
all_subgraphs_are_supported = false;
939+
940+
std::vector<Ort::AttrNameSubgraph> attr_name_subgraphs = parent_node.GetSubgraphs();
941+
for (auto attr_name_subgraph : attr_name_subgraphs) {
942+
auto subgraph = attr_name_subgraph.sub_graph;
943+
const OrtGraph* subgraph_raw_pointer = subgraph;
944+
if (subgraph_raw_pointer != graph) {
945+
946+
size_t num_subgraph_nodes = 0;
947+
THROW_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes));
948+
949+
// Another subgraph of "If" control flow op has no nodes.
950+
// In this case, TRT EP should consider this empty subgraph is fully supported by TRT.
951+
if (num_subgraph_nodes == 0) {
952+
all_subgraphs_are_supported = true;
953+
break;
954+
}
955+
// Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP.
956+
else if (ep->AllNodesAssignedToSpecificEP(subgraph, ep->name_)) {
957+
all_subgraphs_are_supported = true;
958+
break;
959+
}
960+
// Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP.
961+
// (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs)
962+
else if (!ep->AllNodesAssignedToSpecificEP(subgraph, "")) {
963+
all_subgraphs_are_supported = false;
964+
break;
965+
}
966+
967+
std::vector<size_t> subgraph_nodes_vector(num_subgraph_nodes);
968+
std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0);
969+
SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}};
970+
bool subgraph_early_termination = false;
971+
972+
// Another subgraph of "If" control flow has not yet been parsed by GetCapability.
973+
subgraph_supported_nodes_vector = ep->GetSupportedList(parser_subgraph_nodes_vector, 0, ep->max_partition_iterations_, subgraph, &subgraph_early_termination);
974+
all_subgraphs_are_supported = ep->IsSubGraphFullySupported(subgraph, subgraph_supported_nodes_vector);
945975
break;
946976
}
947-
948-
// Another subgraph of "If" control flow has not yet been parsed by GetCapability.
949-
subgraph_supported_nodes_vector = p->GetSupportedList(parser_subgraph_nodes_vector, 0, p->max_partition_iterations_, subgraphs[i], &subgraph_early_termination);
950-
all_subgraphs_are_supported = p->IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes);
951-
break;
952977
}
953-
graph_api_->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count);
954978
}
955979

956980
if (all_subgraphs_are_supported) {
981+
// We want the subgraph nodes to be assigned to TRT EP but don't want them to be fused until later at the control flow op level.
982+
// Simply request the subgraph nodes with a single ComputeCapability for each with no MetaDef (i.e. what the default implementation for IExecutionProvider::GetCapability does).
957983
for (const auto& group : supported_nodes_vector) {
958984
if (!group.first.empty()) {
959985
for (const auto& index : group.first) {
960-
std::unique_ptr<OrtIndexedSubGraph> sub_graph = std::make_unique<OrtIndexedSubGraph>();
961-
sub_graph->node_index_len = 1;
962-
sub_graph->node_index = new size_t[sub_graph->node_index_len];
963-
sub_graph->node_index[0] = nodes_index[index];
964-
cache.push_back(sub_graph.release());
986+
const OrtNode* supported_node = nodes[index];
987+
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_node));
965988
}
966989
}
967990
}
968-
*cnt = cache.size();
969-
*indexed_sub_graph = new OrtIndexedSubGraph*[*cnt];
970-
for (size_t i = 0; i < *cnt; i++) {
971-
(*indexed_sub_graph)[i] = cache[i];
972-
}
973-
// LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
974-
return;
991+
std::string message = "[TensorRT EP] Whole graph will run on TensorRT execution provider";
992+
Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_,
993+
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
994+
message.c_str(), ORT_FILE, __LINE__, __FUNCTION__));
995+
996+
return nullptr;
975997
}
976998
}
977-
*/
978999

9791000
int number_of_trt_nodes = 0;
9801001
for (const auto& group : supported_nodes_vector) {
@@ -2251,7 +2272,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
22512272

22522273
// The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to
22532274
// the session option configurations with the key prefix "ep.<lowercase_ep_name>.".
2254-
// We extract those EP options to create a new "provider options" key/value map.
2275+
// We extract those EP options to create a new "provider options" key-value map.
22552276
std::string lowercase_ep_name = name_.c_str();
22562277
std::transform(lowercase_ep_name.begin(), lowercase_ep_name.end(), lowercase_ep_name.begin(),
22572278
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
@@ -2289,7 +2310,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
22892310
info_ = TensorrtExecutionProviderInfo::FromProviderOptions(provider_options);
22902311
info_.has_trt_options = true;
22912312
device_id_ = info_.device_id;
2292-
// api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device);
22932313

22942314
std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
22952315

@@ -2358,6 +2378,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
23582378
profile_opt_shapes = info_.profile_opt_shapes;
23592379
cuda_graph_enable_ = info_.cuda_graph_enable;
23602380
engine_hw_compatible_ = info_.engine_hw_compatible;
2381+
op_types_to_exclude_ = info_.op_types_to_exclude;
23612382
} else {
23622383
// deprecate env provider option
23632384
}

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,48 +22,6 @@ using AllocateFunc = void* (*)(void*, size_t, size_t);
2222
using DestroyFunc = void (*)(void*, void*);
2323

2424
namespace trt_ep {
25-
namespace tensorrt_env_vars {
26-
static const std::string kMaxPartitionIterations = "ORT_TENSORRT_MAX_PARTITION_ITERATIONS";
27-
static const std::string kMinSubgraphSize = "ORT_TENSORRT_MIN_SUBGRAPH_SIZE";
28-
static const std::string kMaxWorkspaceSize = "ORT_TENSORRT_MAX_WORKSPACE_SIZE";
29-
static const std::string kFP16Enable = "ORT_TENSORRT_FP16_ENABLE";
30-
static const std::string kINT8Enable = "ORT_TENSORRT_INT8_ENABLE";
31-
static const std::string kINT8CalibrationTableName = "ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME";
32-
static const std::string kINT8UseNativeTensorrtCalibrationTable = "ORT_TENSORRT_INT8_USE_NATIVE_CALIBRATION_TABLE";
33-
static const std::string kDLAEnable = "ORT_TENSORRT_DLA_ENABLE";
34-
static const std::string kDLACore = "ORT_TENSORRT_DLA_CORE";
35-
static const std::string kDumpSubgraphs = "ORT_TENSORRT_DUMP_SUBGRAPHS";
36-
static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE";
37-
static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH";
38-
static const std::string kWeightStrippedEngineEnable = "ORT_TENSORRT_WEIGHT_STRIPPED_ENGINE_ENABLE";
39-
static const std::string kOnnxModelFolderPath = "ORT_TENSORRT_ONNX_MODEL_FOLDER_PATH";
40-
// As a timing cache can be used across multiple ONNX files it makes sense to have a separate cache path
41-
static const std::string kTimingCachePath = "ORT_TENSORRT_GLOBAL_CACHE_PATH";
42-
static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE";
43-
static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH";
44-
static const std::string kForceSequentialEngineBuild = "ORT_TENSORRT_FORCE_SEQUENTIAL_ENGINE_BUILD";
45-
static const std::string kContextMemorySharingEnable = "ORT_TENSORRT_CONTEXT_MEMORY_SHARING_ENABLE";
46-
static const std::string kLayerNormFP32Fallback = "ORT_TENSORRT_LAYER_NORM_FP32_FALLBACK";
47-
static const std::string kTimingCacheEnable = "ORT_TENSORRT_TIMING_CACHE_ENABLE";
48-
static const std::string kForceTimingCache = "ORT_TENSORRT_FORCE_TIMING_CACHE_ENABLE";
49-
static const std::string kDetailedBuildLog = "ORT_TENSORRT_DETAILED_BUILD_LOG_ENABLE";
50-
static const std::string kBuildHeuristics = "ORT_TENSORRT_BUILD_HEURISTICS_ENABLE";
51-
static const std::string kSparsityEnable = "ORT_TENSORRT_SPARSITY_ENABLE";
52-
static const std::string kBuilderOptimizationLevel = "ORT_TENSORRT_BUILDER_OPTIMIZATION_LEVEL";
53-
static const std::string kAuxiliaryStreams = "ORT_TENSORRT_AUXILIARY_STREAMS";
54-
static const std::string kTacticSources = "ORT_TENSORRT_TACTIC_SOURCES";
55-
static const std::string kExtraPluginLibPaths = "ORT_TENSORRT_EXTRA_PLUGIN_LIB_PATHS";
56-
static const std::string kProfilesMinShapes = "ORT_TENSORRT_PROFILE_MIN_SHAPES";
57-
static const std::string kProfilesMaxShapes = "ORT_TENSORRT_PROFILE_MAX_SHAPES";
58-
static const std::string kProfilesOptShapes = "ORT_TENSORRT_PROFILE_OPT_SHAPES";
59-
static const std::string kCudaGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE";
60-
static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL";
61-
static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE";
62-
static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE";
63-
static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX";
64-
// Old env variable for backward compatibility
65-
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
66-
} // namespace tensorrt_env_vars
6725

6826
class TensorrtLogger : public nvinfer1::ILogger {
6927
nvinfer1::ILogger::Severity verbosity_;
@@ -386,6 +344,7 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs {
386344
bool cuda_graph_enable_ = false;
387345
std::string cache_prefix_;
388346
bool engine_hw_compatible_ = false;
347+
std::string op_types_to_exclude_;
389348

390349
// For create/dump EP context node model
391350
bool dump_ep_context_model_ = false;
@@ -399,6 +358,8 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs {
399358
std::vector<const char*> extra_attr_keys_;
400359
std::vector<const char*> extra_attr_values_;
401360

361+
std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
362+
402363
// std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_ = ONNX_NAMESPACE::ModelProto::Create();
403364

404365
// mutable std::unordered_map<std::string, std::unique_ptr<SubGraphContext>> subgraph_context_map_;
@@ -442,6 +403,12 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs {
442403
bool IsGraphCaptureAllowed() const { return false; };
443404

444405
nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const;
406+
407+
bool AllNodesAssignedToSpecificEP(const OrtGraph* graph, const std::string& provider_type) const;
408+
409+
bool IsSubGraphOfControlFlowOp(const OrtGraph* graph) const;
410+
411+
bool IsSubGraphFullySupported(const OrtGraph* graph, SubGraphCollection_t supported_nodes_vector) const;
445412
};
446413

447414
/// <summary>

0 commit comments

Comments
 (0)