@@ -1263,15 +1263,22 @@ exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,
12631263
12641264void exec_graph_impl::duplicateNodes () {
12651265 // Map of original modifiable nodes (keys) to new duplicated nodes (values)
1266- std::map<node_impl *, node_impl *> NodesMap;
1267-
1266+ std::unordered_map<node_impl *, node_impl *> NodesMap;
12681267 nodes_range ModifiableNodes{MGraphImpl->MNodeStorage };
1269- std::deque<std::shared_ptr<node_impl>> NewNodes;
1268+ std::vector<std::shared_ptr<node_impl>> NewNodes;
1269+
1270+ const size_t NodeCount = ModifiableNodes.size ();
1271+ NodesMap.reserve (NodeCount);
1272+ NewNodes.reserve (NodeCount);
1273+
1274+ bool foundSubgraph = false ;
12701275
12711276 for (node_impl &OriginalNode : ModifiableNodes) {
12721277 NewNodes.push_back (std::make_shared<node_impl>(OriginalNode));
12731278 node_impl &NodeCopy = *NewNodes.back ();
12741279
1280+ foundSubgraph |= (NodeCopy.MNodeType == node_type::subgraph);
1281+
12751282 // Associate the ID of the original node with the node copy for later quick
12761283 // access
12771284 MIDCache.insert (std::make_pair (OriginalNode.MID , &NodeCopy));
@@ -1300,110 +1307,109 @@ void exec_graph_impl::duplicateNodes() {
13001307
13011308 // Subgraph nodes need special handling, we extract all subgraph nodes and
13021309 // merge them into the main node list
1310+ if (foundSubgraph) {
1311+ for (auto NewNodeIt = NewNodes.rbegin (); NewNodeIt != NewNodes.rend ();
1312+ ++NewNodeIt) {
1313+ auto NewNode = *NewNodeIt;
1314+ if (NewNode->MNodeType != node_type::subgraph) {
1315+ continue ;
1316+ }
1317+ nodes_range SubgraphNodes{NewNode->MSubGraphImpl ->MNodeStorage };
1318+ std::deque<std::shared_ptr<node_impl>> NewSubgraphNodes{};
1319+
1320+ // Map of original subgraph nodes (keys) to new duplicated nodes (values)
1321+ std::map<node_impl *, node_impl *> SubgraphNodesMap;
1322+
1323+ // Copy subgraph nodes
1324+ for (node_impl &SubgraphNode : SubgraphNodes) {
1325+ NewSubgraphNodes.push_back (std::make_shared<node_impl>(SubgraphNode));
1326+ node_impl &NodeCopy = *NewSubgraphNodes.back ();
1327+ // Associate the ID of the original subgraph node with all extracted
1328+ // node copies for future quick access.
1329+ MIDCache.insert (std::make_pair (SubgraphNode.MID , &NodeCopy));
1330+
1331+ SubgraphNodesMap.insert ({&SubgraphNode, &NodeCopy});
1332+ NodeCopy.MSuccessors .clear ();
1333+ NodeCopy.MPredecessors .clear ();
1334+ }
13031335
1304- for (auto NewNodeIt = NewNodes.rbegin (); NewNodeIt != NewNodes.rend ();
1305- ++NewNodeIt) {
1306- auto NewNode = *NewNodeIt;
1307- if (NewNode->MNodeType != node_type::subgraph) {
1308- continue ;
1309- }
1310- nodes_range SubgraphNodes{NewNode->MSubGraphImpl ->MNodeStorage };
1311- std::deque<std::shared_ptr<node_impl>> NewSubgraphNodes{};
1312-
1313- // Map of original subgraph nodes (keys) to new duplicated nodes (values)
1314- std::map<node_impl *, node_impl *> SubgraphNodesMap;
1315-
1316- // Copy subgraph nodes
1317- for (node_impl &SubgraphNode : SubgraphNodes) {
1318- NewSubgraphNodes.push_back (std::make_shared<node_impl>(SubgraphNode));
1319- node_impl &NodeCopy = *NewSubgraphNodes.back ();
1320- // Associate the ID of the original subgraph node with all extracted node
1321- // copies for future quick access.
1322- MIDCache.insert (std::make_pair (SubgraphNode.MID , &NodeCopy));
1323-
1324- SubgraphNodesMap.insert ({&SubgraphNode, &NodeCopy});
1325- NodeCopy.MSuccessors .clear ();
1326- NodeCopy.MPredecessors .clear ();
1327- }
1328-
1329- // Rebuild edges for new subgraph nodes
1330- auto OrigIt = SubgraphNodes.begin (), OrigEnd = SubgraphNodes.end ();
1331- for (auto NewIt = NewSubgraphNodes.begin (); OrigIt != OrigEnd;
1332- ++OrigIt, ++NewIt) {
1333- node_impl &SubgraphNode = *OrigIt;
1334- node_impl &NodeCopy = **NewIt;
1336+ // Rebuild edges for new subgraph nodes
1337+ auto OrigIt = SubgraphNodes.begin (), OrigEnd = SubgraphNodes.end ();
1338+ for (auto NewIt = NewSubgraphNodes.begin (); OrigIt != OrigEnd;
1339+ ++OrigIt, ++NewIt) {
1340+ node_impl &SubgraphNode = *OrigIt;
1341+ node_impl &NodeCopy = **NewIt;
13351342
1336- for (node_impl &NextNode : SubgraphNode.successors ()) {
1337- node_impl &Successor = *SubgraphNodesMap.at (&NextNode);
1338- NodeCopy.registerSuccessor (Successor);
1343+ for (node_impl &NextNode : SubgraphNode.successors ()) {
1344+ node_impl &Successor = *SubgraphNodesMap.at (&NextNode);
1345+ NodeCopy.registerSuccessor (Successor);
1346+ }
13391347 }
1340- }
13411348
1342- // Collect input and output nodes for the subgraph
1343- std::vector<node_impl *> Inputs;
1344- std::vector<node_impl *> Outputs;
1345- for (std::shared_ptr<node_impl> &NodeImpl : NewSubgraphNodes) {
1346- if (NodeImpl->MPredecessors .size () == 0 ) {
1347- Inputs.push_back (&*NodeImpl);
1348- }
1349- if (NodeImpl->MSuccessors .size () == 0 ) {
1350- Outputs.push_back (&*NodeImpl);
1349+ // Collect input and output nodes for the subgraph
1350+ std::vector<node_impl *> Inputs;
1351+ std::vector<node_impl *> Outputs;
1352+ for (std::shared_ptr<node_impl> &NodeImpl : NewSubgraphNodes) {
1353+ if (NodeImpl->MPredecessors .size () == 0 ) {
1354+ Inputs.push_back (&*NodeImpl);
1355+ }
1356+ if (NodeImpl->MSuccessors .size () == 0 ) {
1357+ Outputs.push_back (&*NodeImpl);
1358+ }
13511359 }
1352- }
13531360
1354- // Update the predecessors and successors of the nodes which reference the
1355- // original subgraph node
1361+ // Update the predecessors and successors of the nodes which reference the
1362+ // original subgraph node
13561363
1357- // Predecessors
1358- for (node_impl &PredNode : NewNode->predecessors ()) {
1359- auto &Successors = PredNode.MSuccessors ;
1364+ // Predecessors
1365+ for (node_impl &PredNode : NewNode->predecessors ()) {
1366+ auto &Successors = PredNode.MSuccessors ;
13601367
1361- // Remove the subgraph node from this nodes successors
1362- Successors.erase (
1363- std::remove (Successors.begin (), Successors.end (), NewNode.get ()),
1364- Successors.end ());
1368+ // Remove the subgraph node from this nodes successors
1369+ Successors.erase (
1370+ std::remove (Successors.begin (), Successors.end (), NewNode.get ()),
1371+ Successors.end ());
13651372
1366- // Add all input nodes from the subgraph as successors for this node
1367- // instead
1368- for (node_impl *Input : Inputs) {
1369- PredNode.registerSuccessor (*Input);
1373+ // Add all input nodes from the subgraph as successors for this node
1374+ // instead
1375+ for (node_impl *Input : Inputs) {
1376+ PredNode.registerSuccessor (*Input);
1377+ }
13701378 }
1371- }
13721379
1373- // Successors
1374- for (node_impl &SuccNode : NewNode->successors ()) {
1375- auto &Predecessors = SuccNode.MPredecessors ;
1380+ // Successors
1381+ for (node_impl &SuccNode : NewNode->successors ()) {
1382+ auto &Predecessors = SuccNode.MPredecessors ;
13761383
1377- // Remove the subgraph node from this nodes successors
1378- Predecessors.erase (
1379- std::remove (Predecessors. begin (), Predecessors. end (), NewNode.get ()),
1380- Predecessors.end ());
1384+ // Remove the subgraph node from this nodes successors
1385+ Predecessors.erase (std::remove (Predecessors. begin (), Predecessors. end (),
1386+ NewNode.get ()),
1387+ Predecessors.end ());
13811388
1382- // Add all Output nodes from the subgraph as predecessors for this node
1383- // instead
1384- for (node_impl *Output : Outputs) {
1385- Output->registerSuccessor (SuccNode);
1389+ // Add all Output nodes from the subgraph as predecessors for this node
1390+ // instead
1391+ for (node_impl *Output : Outputs) {
1392+ Output->registerSuccessor (SuccNode);
1393+ }
13861394 }
1387- }
13881395
1389- // Remove single subgraph node and add all new individual subgraph nodes
1390- // to the node storage in its place
1391- auto OldPositionIt =
1392- NewNodes.erase (std::find (NewNodes.begin (), NewNodes.end (), NewNode));
1393- // Also set the iterator to the newly added nodes so we can continue
1394- // iterating over all remaining nodes
1395- auto InsertIt = NewNodes.insert (
1396- OldPositionIt, std::make_move_iterator (NewSubgraphNodes.begin ()),
1397- std::make_move_iterator (NewSubgraphNodes.end ()));
1398- // Since the new reverse_iterator will be at i - 1 we need to advance it
1399- // when constructing
1400- NewNodeIt = std::make_reverse_iterator (std::next (InsertIt));
1396+ // Remove single subgraph node and add all new individual subgraph nodes
1397+ // to the node storage in its place
1398+ auto OldPositionIt =
1399+ NewNodes.erase (std::find (NewNodes.begin (), NewNodes.end (), NewNode));
1400+ // Also set the iterator to the newly added nodes so we can continue
1401+ // iterating over all remaining nodes
1402+ auto InsertIt = NewNodes.insert (
1403+ OldPositionIt, std::make_move_iterator (NewSubgraphNodes.begin ()),
1404+ std::make_move_iterator (NewSubgraphNodes.end ()));
1405+ // Since the new reverse_iterator will be at i - 1 we need to advance it
1406+ // when constructing
1407+ NewNodeIt = std::make_reverse_iterator (std::next (InsertIt));
1408+ }
14011409 }
14021410
14031411 // Store all the new nodes locally
1404- MNodeStorage.insert (MNodeStorage.begin (),
1405- std::make_move_iterator (NewNodes.begin ()),
1406- std::make_move_iterator (NewNodes.end ()));
1412+ MNodeStorage = std::move (NewNodes);
14071413}
14081414
14091415void exec_graph_impl::update (std::shared_ptr<graph_impl> GraphImpl) {
0 commit comments