Skip to content

Commit 1a821de

Browse files
authored
[SYCL][Graph] Graph duplication optimizations in finalize() (#20547)
When invoking `auto g = graph.finalize()` on a graph with a moderate amount of nodes, duplicating the nodes of the modifiable graph into the executable graph takes a non-trivial amount of time. This PR introduces a few changes to improve the performance of `duplicateNodes()`: 1. Replace the `std::dequeue` with a `std::vector` so that it can be directly moved, omitting unnecessary O(n) allocations 2. Keep track of whether there are any subgraphs in the first pass of the graph and skip the second pass of the graph if there are no subgraphs 3. Add a `reserve()` for the vector of nodes and map that we are going to create, since we know the total number of nodes. > Note that 1. optimizes the fast path when the graph does not contain a subgraph, but will likely make finalize slower for graphs that contain subgraphs. The changes for 2. make the diff seem much larger than it actually is, but the majority of it is whitespace from moving the subgraph handling into an if statement.
1 parent 64bc42d commit 1a821de

File tree

1 file changed

+95
-89
lines changed

1 file changed

+95
-89
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 95 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,15 +1263,22 @@ exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,
12631263

12641264
void exec_graph_impl::duplicateNodes() {
12651265
// Map of original modifiable nodes (keys) to new duplicated nodes (values)
1266-
std::map<node_impl *, node_impl *> NodesMap;
1267-
1266+
std::unordered_map<node_impl *, node_impl *> NodesMap;
12681267
nodes_range ModifiableNodes{MGraphImpl->MNodeStorage};
1269-
std::deque<std::shared_ptr<node_impl>> NewNodes;
1268+
std::vector<std::shared_ptr<node_impl>> NewNodes;
1269+
1270+
const size_t NodeCount = ModifiableNodes.size();
1271+
NodesMap.reserve(NodeCount);
1272+
NewNodes.reserve(NodeCount);
1273+
1274+
bool foundSubgraph = false;
12701275

12711276
for (node_impl &OriginalNode : ModifiableNodes) {
12721277
NewNodes.push_back(std::make_shared<node_impl>(OriginalNode));
12731278
node_impl &NodeCopy = *NewNodes.back();
12741279

1280+
foundSubgraph |= (NodeCopy.MNodeType == node_type::subgraph);
1281+
12751282
// Associate the ID of the original node with the node copy for later quick
12761283
// access
12771284
MIDCache.insert(std::make_pair(OriginalNode.MID, &NodeCopy));
@@ -1300,110 +1307,109 @@ void exec_graph_impl::duplicateNodes() {
13001307

13011308
// Subgraph nodes need special handling, we extract all subgraph nodes and
13021309
// merge them into the main node list
1310+
if (foundSubgraph) {
1311+
for (auto NewNodeIt = NewNodes.rbegin(); NewNodeIt != NewNodes.rend();
1312+
++NewNodeIt) {
1313+
auto NewNode = *NewNodeIt;
1314+
if (NewNode->MNodeType != node_type::subgraph) {
1315+
continue;
1316+
}
1317+
nodes_range SubgraphNodes{NewNode->MSubGraphImpl->MNodeStorage};
1318+
std::deque<std::shared_ptr<node_impl>> NewSubgraphNodes{};
1319+
1320+
// Map of original subgraph nodes (keys) to new duplicated nodes (values)
1321+
std::map<node_impl *, node_impl *> SubgraphNodesMap;
1322+
1323+
// Copy subgraph nodes
1324+
for (node_impl &SubgraphNode : SubgraphNodes) {
1325+
NewSubgraphNodes.push_back(std::make_shared<node_impl>(SubgraphNode));
1326+
node_impl &NodeCopy = *NewSubgraphNodes.back();
1327+
// Associate the ID of the original subgraph node with all extracted
1328+
// node copies for future quick access.
1329+
MIDCache.insert(std::make_pair(SubgraphNode.MID, &NodeCopy));
1330+
1331+
SubgraphNodesMap.insert({&SubgraphNode, &NodeCopy});
1332+
NodeCopy.MSuccessors.clear();
1333+
NodeCopy.MPredecessors.clear();
1334+
}
13031335

1304-
for (auto NewNodeIt = NewNodes.rbegin(); NewNodeIt != NewNodes.rend();
1305-
++NewNodeIt) {
1306-
auto NewNode = *NewNodeIt;
1307-
if (NewNode->MNodeType != node_type::subgraph) {
1308-
continue;
1309-
}
1310-
nodes_range SubgraphNodes{NewNode->MSubGraphImpl->MNodeStorage};
1311-
std::deque<std::shared_ptr<node_impl>> NewSubgraphNodes{};
1312-
1313-
// Map of original subgraph nodes (keys) to new duplicated nodes (values)
1314-
std::map<node_impl *, node_impl *> SubgraphNodesMap;
1315-
1316-
// Copy subgraph nodes
1317-
for (node_impl &SubgraphNode : SubgraphNodes) {
1318-
NewSubgraphNodes.push_back(std::make_shared<node_impl>(SubgraphNode));
1319-
node_impl &NodeCopy = *NewSubgraphNodes.back();
1320-
// Associate the ID of the original subgraph node with all extracted node
1321-
// copies for future quick access.
1322-
MIDCache.insert(std::make_pair(SubgraphNode.MID, &NodeCopy));
1323-
1324-
SubgraphNodesMap.insert({&SubgraphNode, &NodeCopy});
1325-
NodeCopy.MSuccessors.clear();
1326-
NodeCopy.MPredecessors.clear();
1327-
}
1328-
1329-
// Rebuild edges for new subgraph nodes
1330-
auto OrigIt = SubgraphNodes.begin(), OrigEnd = SubgraphNodes.end();
1331-
for (auto NewIt = NewSubgraphNodes.begin(); OrigIt != OrigEnd;
1332-
++OrigIt, ++NewIt) {
1333-
node_impl &SubgraphNode = *OrigIt;
1334-
node_impl &NodeCopy = **NewIt;
1336+
// Rebuild edges for new subgraph nodes
1337+
auto OrigIt = SubgraphNodes.begin(), OrigEnd = SubgraphNodes.end();
1338+
for (auto NewIt = NewSubgraphNodes.begin(); OrigIt != OrigEnd;
1339+
++OrigIt, ++NewIt) {
1340+
node_impl &SubgraphNode = *OrigIt;
1341+
node_impl &NodeCopy = **NewIt;
13351342

1336-
for (node_impl &NextNode : SubgraphNode.successors()) {
1337-
node_impl &Successor = *SubgraphNodesMap.at(&NextNode);
1338-
NodeCopy.registerSuccessor(Successor);
1343+
for (node_impl &NextNode : SubgraphNode.successors()) {
1344+
node_impl &Successor = *SubgraphNodesMap.at(&NextNode);
1345+
NodeCopy.registerSuccessor(Successor);
1346+
}
13391347
}
1340-
}
13411348

1342-
// Collect input and output nodes for the subgraph
1343-
std::vector<node_impl *> Inputs;
1344-
std::vector<node_impl *> Outputs;
1345-
for (std::shared_ptr<node_impl> &NodeImpl : NewSubgraphNodes) {
1346-
if (NodeImpl->MPredecessors.size() == 0) {
1347-
Inputs.push_back(&*NodeImpl);
1348-
}
1349-
if (NodeImpl->MSuccessors.size() == 0) {
1350-
Outputs.push_back(&*NodeImpl);
1349+
// Collect input and output nodes for the subgraph
1350+
std::vector<node_impl *> Inputs;
1351+
std::vector<node_impl *> Outputs;
1352+
for (std::shared_ptr<node_impl> &NodeImpl : NewSubgraphNodes) {
1353+
if (NodeImpl->MPredecessors.size() == 0) {
1354+
Inputs.push_back(&*NodeImpl);
1355+
}
1356+
if (NodeImpl->MSuccessors.size() == 0) {
1357+
Outputs.push_back(&*NodeImpl);
1358+
}
13511359
}
1352-
}
13531360

1354-
// Update the predecessors and successors of the nodes which reference the
1355-
// original subgraph node
1361+
// Update the predecessors and successors of the nodes which reference the
1362+
// original subgraph node
13561363

1357-
// Predecessors
1358-
for (node_impl &PredNode : NewNode->predecessors()) {
1359-
auto &Successors = PredNode.MSuccessors;
1364+
// Predecessors
1365+
for (node_impl &PredNode : NewNode->predecessors()) {
1366+
auto &Successors = PredNode.MSuccessors;
13601367

1361-
// Remove the subgraph node from this nodes successors
1362-
Successors.erase(
1363-
std::remove(Successors.begin(), Successors.end(), NewNode.get()),
1364-
Successors.end());
1368+
// Remove the subgraph node from this nodes successors
1369+
Successors.erase(
1370+
std::remove(Successors.begin(), Successors.end(), NewNode.get()),
1371+
Successors.end());
13651372

1366-
// Add all input nodes from the subgraph as successors for this node
1367-
// instead
1368-
for (node_impl *Input : Inputs) {
1369-
PredNode.registerSuccessor(*Input);
1373+
// Add all input nodes from the subgraph as successors for this node
1374+
// instead
1375+
for (node_impl *Input : Inputs) {
1376+
PredNode.registerSuccessor(*Input);
1377+
}
13701378
}
1371-
}
13721379

1373-
// Successors
1374-
for (node_impl &SuccNode : NewNode->successors()) {
1375-
auto &Predecessors = SuccNode.MPredecessors;
1380+
// Successors
1381+
for (node_impl &SuccNode : NewNode->successors()) {
1382+
auto &Predecessors = SuccNode.MPredecessors;
13761383

1377-
// Remove the subgraph node from this nodes successors
1378-
Predecessors.erase(
1379-
std::remove(Predecessors.begin(), Predecessors.end(), NewNode.get()),
1380-
Predecessors.end());
1384+
// Remove the subgraph node from this nodes successors
1385+
Predecessors.erase(std::remove(Predecessors.begin(), Predecessors.end(),
1386+
NewNode.get()),
1387+
Predecessors.end());
13811388

1382-
// Add all Output nodes from the subgraph as predecessors for this node
1383-
// instead
1384-
for (node_impl *Output : Outputs) {
1385-
Output->registerSuccessor(SuccNode);
1389+
// Add all Output nodes from the subgraph as predecessors for this node
1390+
// instead
1391+
for (node_impl *Output : Outputs) {
1392+
Output->registerSuccessor(SuccNode);
1393+
}
13861394
}
1387-
}
13881395

1389-
// Remove single subgraph node and add all new individual subgraph nodes
1390-
// to the node storage in its place
1391-
auto OldPositionIt =
1392-
NewNodes.erase(std::find(NewNodes.begin(), NewNodes.end(), NewNode));
1393-
// Also set the iterator to the newly added nodes so we can continue
1394-
// iterating over all remaining nodes
1395-
auto InsertIt = NewNodes.insert(
1396-
OldPositionIt, std::make_move_iterator(NewSubgraphNodes.begin()),
1397-
std::make_move_iterator(NewSubgraphNodes.end()));
1398-
// Since the new reverse_iterator will be at i - 1 we need to advance it
1399-
// when constructing
1400-
NewNodeIt = std::make_reverse_iterator(std::next(InsertIt));
1396+
// Remove single subgraph node and add all new individual subgraph nodes
1397+
// to the node storage in its place
1398+
auto OldPositionIt =
1399+
NewNodes.erase(std::find(NewNodes.begin(), NewNodes.end(), NewNode));
1400+
// Also set the iterator to the newly added nodes so we can continue
1401+
// iterating over all remaining nodes
1402+
auto InsertIt = NewNodes.insert(
1403+
OldPositionIt, std::make_move_iterator(NewSubgraphNodes.begin()),
1404+
std::make_move_iterator(NewSubgraphNodes.end()));
1405+
// Since the new reverse_iterator will be at i - 1 we need to advance it
1406+
// when constructing
1407+
NewNodeIt = std::make_reverse_iterator(std::next(InsertIt));
1408+
}
14011409
}
14021410

14031411
// Store all the new nodes locally
1404-
MNodeStorage.insert(MNodeStorage.begin(),
1405-
std::make_move_iterator(NewNodes.begin()),
1406-
std::make_move_iterator(NewNodes.end()));
1412+
MNodeStorage = std::move(NewNodes);
14071413
}
14081414

14091415
void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {

0 commit comments

Comments
 (0)