Skip to content

Commit ed65a9f

Browse files
committed
clean up GetCapabilityImpl and make it pass compiler for now
1 parent 36c0dc1 commit ed65a9f

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,8 +1330,9 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
13301330
OrtEpGraphSupportInfo* graph_support_info) {
13311331
TensorrtExecutionProvider* ep = static_cast<TensorrtExecutionProvider*>(this_ptr);
13321332
const OrtApi& ort_api = ep->ort_api;
1333-
/*
1333+
13341334
// Get ModelPath
1335+
/*
13351336
const std::filesystem::path* model_path = nullptr;
13361337
graph_api_->OrtGraph_GetModelPath(graph, reinterpret_cast<const void**>(&model_path));
13371338
const auto& path_string = model_path->string();
@@ -1387,6 +1388,8 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
13871388
SubGraphCollection_t parser_nodes_vector, supported_nodes_vector;
13881389
bool new_subgraph = true;
13891390

1391+
std::unordered_set<std::string> control_flow_op_set = {"If", "Loop", "Scan"};
1392+
13901393
/* Iterate all the nodes and exclude the node if:
13911394
* 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
13921395
* 2. Its op type is in the exclusion list.
@@ -1407,7 +1410,7 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
14071410
const char* op_type = nullptr;
14081411
RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type));
14091412

1410-
if (ep->control_flow_op_set_.find(op_type) != ep->control_flow_op_set_.end()) {
1413+
if (control_flow_op_set.find(op_type) != control_flow_op_set.end()) {
14111414
auto supported_control_flow_op = [&](const OrtNode* node) {
14121415
OrtStatus* status = nullptr;
14131416
size_t num_subgraphs = 0;
@@ -1467,24 +1470,31 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
14671470
}
14681471
}
14691472

1473+
1474+
// Use this local definitions for now
1475+
// TODO: Use provider option
1476+
int max_partition_iterations = 1000;
1477+
int min_subgraph_size = 1;
1478+
14701479
bool early_termination = false;
1471-
supported_nodes_vector = ep->GetSupportedList(parser_nodes_vector, 0, p->max_partition_iterations_, graph, &early_termination);
1480+
supported_nodes_vector = ep->GetSupportedList(parser_nodes_vector, 0, max_partition_iterations, graph, &early_termination);
14721481
if (early_termination) {
14731482
supported_nodes_vector.clear();
14741483
}
14751484

14761485
// Remove subgraphs if its size is less than the predefined minimal size
14771486
for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) {
14781487
const size_t subgraph_size = it->first.size();
1479-
if (subgraph_size < p->min_subgraph_size_) {
1488+
if (subgraph_size < min_subgraph_size) {
14801489
supported_nodes_vector.erase(it--);
14811490
}
14821491
}
14831492

14841493
// Detect and remove cycles from supported node list
1485-
//p->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash);
1494+
/* ep->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); */
14861495

14871496
// Consolidate supported node list
1497+
/*
14881498
if (supported_nodes_vector.size() > 1) {
14891499
nodes_vector.clear();
14901500
for (const auto& group : supported_nodes_vector) {
@@ -1500,11 +1510,12 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
15001510
supported_nodes_vector = consolidated_supported_nodes_vector;
15011511
}
15021512
}
1513+
*/
15031514

1504-
std::vector<OrtIndexedSubGraph*> cache;
15051515
// Handle the case where the graph is subgraph of control flow op.
15061516
// The purpose is to make control flow op as well as its subgraphs run on TRT.
15071517
// Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level.
1518+
/*
15081519
if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) {
15091520
bool all_subgraphs_are_supported = true;
15101521
@@ -1580,32 +1591,33 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph
15801591
return;
15811592
}
15821593
}
1594+
*/
15831595

1584-
int number_of_trt_nodes = 0, subgraph_index = 0;
1596+
int number_of_trt_nodes = 0;
15851597
for (const auto& group : supported_nodes_vector) {
15861598
if (!group.first.empty()) {
1587-
std::unique_ptr<OrtIndexedSubGraph> sub_graph = p->GetSubGraph(group, graph, model_hash, subgraph_index);
1588-
cache.push_back(sub_graph.release());
1599+
std::vector<const OrtNode*> supported_nodes;
1600+
for (const auto& index : group.first) {
1601+
const OrtNode* supported_node = nullptr;
1602+
RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetElementAt(nodes_container, index,
1603+
reinterpret_cast<const void**>(&supported_node)));
1604+
supported_nodes.push_back(supported_node);
1605+
}
1606+
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(),
1607+
supported_nodes.size()));
15891608
number_of_trt_nodes += static_cast<int>(group.first.size());
1590-
subgraph_index++;
15911609
}
15921610
}
15931611

15941612
const size_t number_of_subgraphs = supported_nodes_vector.size();
15951613
if (number_of_trt_nodes == 0) {
15961614
// LOGS_DEFAULT(WARNING) << "[TensorRT EP] No graph will run on TensorRT execution provider";
1597-
} else if (number_of_trt_nodes == number_of_ort_nodes) {
1615+
} else if (number_of_trt_nodes == nodes.size()) {
15981616
// LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
15991617
} else {
16001618
// LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
16011619
}
16021620

1603-
*cnt = cache.size();
1604-
*indexed_sub_graph = new OrtIndexedSubGraph*[*cnt];
1605-
for (size_t i = 0; i < *cnt; i++) {
1606-
(*indexed_sub_graph)[i] = cache[i];
1607-
}
1608-
16091621
return nullptr;
16101622
}
16111623

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
243243
const OrtSessionOptions& session_options_;
244244
const OrtLogger& logger_;
245245

246+
SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations,
247+
const OrtGraph* graph, bool* early_termination) const;
246248

247249
/*
248250
bool IsGraphCaptured(int graph_annotation_id) const { return false; }
@@ -283,7 +285,7 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
283285
std::unordered_map<std::string, std::unordered_map<std::string, float>> dynamic_range_map_;
284286
std::unordered_map<std::string, std::string> cache_suffix_;
285287

286-
//private:
288+
private:
287289
mutable TensorrtExecutionProviderInfo info_;
288290
bool external_stream_ = false;
289291
cudaStream_t stream_ = nullptr;
@@ -346,7 +348,6 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
346348

347349
// std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_ = ONNX_NAMESPACE::ModelProto::Create();
348350

349-
std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
350351
// mutable std::unordered_map<std::string, std::unique_ptr<SubGraphContext>> subgraph_context_map_;
351352

352353
mutable std::unique_ptr<nvinfer1::IBuilder> builder_;

0 commit comments

Comments
 (0)