@@ -142,12 +142,11 @@ void propagatePartitionUp(node_impl &Node, int PartitionNum) {
142142// / @param PartitionNum Number to propagate.
143143// / @param HostTaskList List of host tasks that have already been processed and
144144// / are encountered as successors to the node Node.
145- void propagatePartitionDown (
146- node_impl &Node, int PartitionNum,
147- std::list<std::shared_ptr<node_impl>> &HostTaskList) {
145+ void propagatePartitionDown (node_impl &Node, int PartitionNum,
146+ std::list<node_impl *> &HostTaskList) {
148147 if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
149148 if (Node.MPartitionNum != -1 ) {
150- HostTaskList.push_front (Node. shared_from_this () );
149+ HostTaskList.push_front (& Node);
151150 }
152151 return ;
153152 }
@@ -181,11 +180,11 @@ void partition::updateSchedule() {
181180
182181void exec_graph_impl::makePartitions () {
183182 int CurrentPartition = -1 ;
184- std::list<std::shared_ptr< node_impl> > HostTaskList;
183+ std::list<node_impl * > HostTaskList;
185184 // find all the host-tasks in the graph
186- for (auto &Node : MNodeStorage ) {
187- if (Node-> MCGType == sycl::detail::CGType::CodeplayHostTask) {
188- HostTaskList.push_back (Node);
185+ for (node_impl &Node : nodes () ) {
186+ if (Node. MCGType == sycl::detail::CGType::CodeplayHostTask) {
187+ HostTaskList.push_back (& Node);
189188 }
190189 }
191190
@@ -215,29 +214,29 @@ void exec_graph_impl::makePartitions() {
215214 // group that includes the predecessor of `B` can be merged with the group of
216215 // the predecessors of the node `A`.
217216 while (HostTaskList.size () > 0 ) {
218- auto Node = HostTaskList.front ();
217+ node_impl & Node = * HostTaskList.front ();
219218 HostTaskList.pop_front ();
220219 CurrentPartition++;
221- for (node_impl &Predecessor : Node-> predecessors ()) {
220+ for (node_impl &Predecessor : Node. predecessors ()) {
222221 propagatePartitionUp (Predecessor, CurrentPartition);
223222 }
224223 CurrentPartition++;
225- Node-> MPartitionNum = CurrentPartition;
224+ Node. MPartitionNum = CurrentPartition;
226225 CurrentPartition++;
227226 auto TmpSize = HostTaskList.size ();
228- for (node_impl &Successor : Node-> successors ()) {
227+ for (node_impl &Successor : Node. successors ()) {
229228 propagatePartitionDown (Successor, CurrentPartition, HostTaskList);
230229 }
231230 if (HostTaskList.size () > TmpSize) {
232231 // At least one HostTask has been re-numbered so group merge opportunities
233- for (const auto & HT : HostTaskList) {
232+ for (node_impl * HT : HostTaskList) {
234233 auto HTPartitionNum = HT->MPartitionNum ;
235234 if (HTPartitionNum != -1 ) {
236235 // can merge predecessors of node `Node` with predecessors of node
237236 // `HT` (HTPartitionNum-1) since HT must be reprocessed
238- for (const auto &NodeImpl : MNodeStorage ) {
239- if (NodeImpl-> MPartitionNum == Node-> MPartitionNum - 1 ) {
240- NodeImpl-> MPartitionNum = HTPartitionNum - 1 ;
237+ for (node_impl &NodeImpl : nodes () ) {
238+ if (NodeImpl. MPartitionNum == Node. MPartitionNum - 1 ) {
239+ NodeImpl. MPartitionNum = HTPartitionNum - 1 ;
241240 }
242241 }
243242 } else {
@@ -251,12 +250,12 @@ void exec_graph_impl::makePartitions() {
251250 int PartitionFinalNum = 0 ;
252251 for (int i = -1 ; i <= CurrentPartition; i++) {
253252 const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
254- for (auto &Node : MNodeStorage ) {
255- if (Node-> MPartitionNum == i) {
256- MPartitionNodes[Node. get () ] = PartitionFinalNum;
257- if (isPartitionRoot (* Node)) {
258- Partition->MRoots .insert (Node. get () );
259- if (Node-> MCGType == CGType::CodeplayHostTask) {
253+ for (node_impl &Node : nodes () ) {
254+ if (Node. MPartitionNum == i) {
255+ MPartitionNodes[& Node] = PartitionFinalNum;
256+ if (isPartitionRoot (Node)) {
257+ Partition->MRoots .insert (& Node);
258+ if (Node. MCGType == CGType::CodeplayHostTask) {
260259 Partition->MIsHostTask = true ;
261260 }
262261 }
@@ -295,8 +294,8 @@ void exec_graph_impl::makePartitions() {
295294 }
296295
297296 // Reset node groups (if node have to be re-processed - e.g. subgraph)
298- for (auto &Node : MNodeStorage ) {
299- Node-> MPartitionNum = -1 ;
297+ for (node_impl &Node : nodes () ) {
298+ Node. MPartitionNum = -1 ;
300299 }
301300}
302301
@@ -376,19 +375,19 @@ std::set<node_impl *> graph_impl::getCGEdges(
376375 // A unique set of dependencies obtained by checking requirements and events
377376 for (auto &Req : Requirements) {
378377 // Look through the graph for nodes which share this requirement
379- for (auto &Node : MNodeStorage ) {
380- if (Node-> hasRequirementDependency (Req)) {
378+ for (node_impl &Node : nodes () ) {
379+ if (Node. hasRequirementDependency (Req)) {
381380 bool ShouldAddDep = true ;
382381 // If any of this node's successors have this requirement then we skip
383382 // adding the current node as a dependency.
384- for (node_impl &Succ : Node-> successors ()) {
383+ for (node_impl &Succ : Node. successors ()) {
385384 if (Succ.hasRequirementDependency (Req)) {
386385 ShouldAddDep = false ;
387386 break ;
388387 }
389388 }
390389 if (ShouldAddDep) {
391- UniqueDeps.insert (Node. get () );
390+ UniqueDeps.insert (& Node);
392391 }
393392 }
394393 }
@@ -487,7 +486,7 @@ node_impl &graph_impl::add(std::function<void(handler &)> CGF,
487486 }
488487
489488 for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
490- DynamicParam->registerNode (NodeImpl. shared_from_this () , ArgIndex);
489+ DynamicParam->registerNode (NodeImpl, ArgIndex);
491490 }
492491
493492 return NodeImpl;
@@ -611,10 +610,9 @@ void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
611610 MInorderQueueMap[Queue.weak_from_this ()] = &Node;
612611}
613612
614- void graph_impl::makeEdge (std::shared_ptr<node_impl> Src,
615- std::shared_ptr<node_impl> Dest) {
613+ void graph_impl::makeEdge (node_impl &Src, node_impl &Dest) {
616614 throwIfGraphRecordingQueue (" make_edge()" );
617- if (Src == Dest) {
615+ if (& Src == & Dest) {
618616 throw sycl::exception (
619617 make_error_code (sycl::errc::invalid),
620618 " make_edge() cannot be called when Src and Dest are the same." );
@@ -624,8 +622,8 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
624622 bool DestFound = false ;
625623 for (const auto &Node : MNodeStorage) {
626624
627- SrcFound |= Node == Src;
628- DestFound |= Node == Dest;
625+ SrcFound |= Node. get () == & Src;
626+ DestFound |= Node. get () == & Dest;
629627
630628 if (SrcFound && DestFound) {
631629 break ;
@@ -641,49 +639,49 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
641639 " Dest must be a node inside the graph." );
642640 }
643641
644- bool DestWasGraphRoot = Dest-> MPredecessors .size () == 0 ;
642+ bool DestWasGraphRoot = Dest. MPredecessors .size () == 0 ;
645643
646644 // We need to add the edges first before checking for cycles
647- Src-> registerSuccessor (* Dest);
645+ Src. registerSuccessor (Dest);
648646
649- bool DestLostRootStatus = DestWasGraphRoot && Dest-> MPredecessors .size () == 1 ;
647+ bool DestLostRootStatus = DestWasGraphRoot && Dest. MPredecessors .size () == 1 ;
650648 if (DestLostRootStatus) {
651649 // Dest is no longer a Root node, so we need to remove it from MRoots.
652- MRoots.erase (Dest. get () );
650+ MRoots.erase (& Dest);
653651 }
654652
655653 // We can skip cycle checks if either Dest has no successors (cycle not
656654 // possible) or cycle checks have been disabled with the no_cycle_check
657655 // property;
658- if (Dest-> MSuccessors .empty () || !MSkipCycleChecks) {
656+ if (Dest. MSuccessors .empty () || !MSkipCycleChecks) {
659657 bool CycleFound = checkForCycles ();
660658
661659 if (CycleFound) {
662660 // Remove the added successor and predecessor.
663- Src-> MSuccessors .pop_back ();
664- Dest-> MPredecessors .pop_back ();
661+ Src. MSuccessors .pop_back ();
662+ Dest. MPredecessors .pop_back ();
665663 if (DestLostRootStatus) {
666664 // Add Dest back into MRoots.
667- MRoots.insert (Dest. get () );
665+ MRoots.insert (& Dest);
668666 }
669667
670668 throw sycl::exception (make_error_code (sycl::errc::invalid),
671669 " Command graphs cannot contain cycles." );
672670 }
673671 }
674- removeRoot (* Dest); // remove receiver from root node list
672+ removeRoot (Dest); // remove receiver from root node list
675673}
676674
677675std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents (
678676 std::weak_ptr<sycl::detail::queue_impl> RecordedQueue) {
679677 std::vector<sycl::detail::EventImplPtr> Events;
680678
681679 auto RecordedQueueSP = RecordedQueue.lock ();
682- for (auto &Node : MNodeStorage ) {
683- if (Node-> MSuccessors .empty ()) {
684- auto EventForNode = getEventForNode (* Node);
680+ for (node_impl &Node : nodes () ) {
681+ if (Node. MSuccessors .empty ()) {
682+ auto EventForNode = getEventForNode (Node);
685683 if (EventForNode->getSubmittedQueue () == RecordedQueueSP) {
686- Events.push_back (getEventForNode (* Node));
684+ Events.push_back (getEventForNode (Node));
687685 }
688686 }
689687 }
@@ -1433,15 +1431,14 @@ void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {
14331431 std::make_pair (GraphImpl->MNodeStorage [i]->MID , MNodeStorage[i].get ()));
14341432 }
14351433
1436- update (GraphImpl->MNodeStorage );
1434+ update (GraphImpl->nodes () );
14371435}
14381436
1439- void exec_graph_impl::update (std::shared_ptr< node_impl> Node) {
1440- this ->update (std::vector<std::shared_ptr< node_impl>>{ Node});
1437+ void exec_graph_impl::update (node_impl & Node) {
1438+ this ->update (std::vector<node_impl *>{& Node});
14411439}
14421440
1443- void exec_graph_impl::update (
1444- const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1441+ void exec_graph_impl::update (nodes_range Nodes) {
14451442 if (!MIsUpdatable) {
14461443 throw sycl::exception (sycl::make_error_code (errc::invalid),
14471444 " update() cannot be called on a executable graph "
@@ -1502,7 +1499,7 @@ void exec_graph_impl::update(
15021499}
15031500
15041501bool exec_graph_impl::needsScheduledUpdate (
1505- const std::vector<std::shared_ptr<node_impl>> & Nodes,
1502+ nodes_range Nodes,
15061503 std::vector<sycl::detail::AccessorImplHost *> &UpdateRequirements) {
15071504 // If there are any accessor requirements, we have to update through the
15081505 // scheduler to ensure that any allocations have taken place before trying to
@@ -1511,30 +1508,30 @@ bool exec_graph_impl::needsScheduledUpdate(
15111508 // At worst we may have as many requirements as there are for the entire graph
15121509 // for updating.
15131510 UpdateRequirements.reserve (MRequirements.size ());
1514- for (auto &Node : Nodes) {
1511+ for (node_impl &Node : Nodes) {
15151512 // Check if node(s) derived from this modifiable node exists in this graph
1516- if (MIDCache.count (Node-> getID ()) == 0 ) {
1513+ if (MIDCache.count (Node. getID ()) == 0 ) {
15171514 throw sycl::exception (
15181515 sycl::make_error_code (errc::invalid),
15191516 " Node passed to update() is not part of the graph." );
15201517 }
15211518
1522- if (!Node-> isUpdatable ()) {
1519+ if (!Node. isUpdatable ()) {
15231520 std::string ErrorString = " node_type::" ;
1524- ErrorString += nodeTypeToString (Node-> MNodeType );
1521+ ErrorString += nodeTypeToString (Node. MNodeType );
15251522 ErrorString +=
15261523 " nodes are not supported for update. Only kernel, host_task, "
15271524 " barrier and empty nodes are supported." ;
15281525 throw sycl::exception (errc::invalid, ErrorString);
15291526 }
15301527
1531- if (const auto &CG = Node-> MCommandGroup ;
1528+ if (const auto &CG = Node. MCommandGroup ;
15321529 CG && CG->getRequirements ().size () != 0 ) {
15331530 NeedScheduledUpdate = true ;
15341531
15351532 UpdateRequirements.insert (UpdateRequirements.end (),
1536- Node-> MCommandGroup ->getRequirements ().begin (),
1537- Node-> MCommandGroup ->getRequirements ().end ());
1533+ Node. MCommandGroup ->getRequirements ().begin (),
1534+ Node. MCommandGroup ->getRequirements ().end ());
15381535 }
15391536 }
15401537
@@ -1740,18 +1737,17 @@ exec_graph_impl::getURUpdatableNodes(nodes_range Nodes) const {
17401737 return PartitionedNodes;
17411738}
17421739
1743- void exec_graph_impl::updateHostTasksImpl (
1744- const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1745- for (auto &Node : Nodes) {
1746- if (Node->MNodeType != node_type::host_task) {
1740+ void exec_graph_impl::updateHostTasksImpl (nodes_range Nodes) const {
1741+ for (node_impl &Node : Nodes) {
1742+ if (Node.MNodeType != node_type::host_task) {
17471743 continue ;
17481744 }
17491745 // Query the ID cache to find the equivalent exec node for the node passed
17501746 // to this function.
1751- auto ExecNode = MIDCache.find (Node-> MID );
1747+ auto ExecNode = MIDCache.find (Node. MID );
17521748 assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
17531749
1754- ExecNode->second ->updateFromOtherNode (* Node);
1750+ ExecNode->second ->updateFromOtherNode (Node);
17551751 }
17561752}
17571753
@@ -1852,21 +1848,18 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
18521848void modifiable_command_graph::addGraphLeafDependencies (node Node) {
18531849 // Find all exit nodes in the current graph and add them to the dependency
18541850 // vector
1855- std::shared_ptr<detail::node_impl> DstImpl =
1856- sycl::detail::getSyclObjImpl (Node);
1851+ detail::node_impl &DstImpl = *sycl::detail::getSyclObjImpl (Node);
18571852 graph_impl::WriteLock Lock (impl->MMutex );
18581853 for (auto &NodeImpl : impl->MNodeStorage ) {
1859- if ((NodeImpl->MSuccessors .size () == 0 ) && (NodeImpl != DstImpl)) {
1860- impl->makeEdge (NodeImpl, DstImpl);
1854+ if ((NodeImpl->MSuccessors .size () == 0 ) && (NodeImpl. get () != & DstImpl)) {
1855+ impl->makeEdge (* NodeImpl, DstImpl);
18611856 }
18621857 }
18631858}
18641859
18651860void modifiable_command_graph::make_edge (node &Src, node &Dest) {
1866- std::shared_ptr<detail::node_impl> SenderImpl =
1867- sycl::detail::getSyclObjImpl (Src);
1868- std::shared_ptr<detail::node_impl> ReceiverImpl =
1869- sycl::detail::getSyclObjImpl (Dest);
1861+ detail::node_impl &SenderImpl = *sycl::detail::getSyclObjImpl (Src);
1862+ detail::node_impl &ReceiverImpl = *sycl::detail::getSyclObjImpl (Dest);
18701863
18711864 graph_impl::WriteLock Lock (impl->MMutex );
18721865 impl->makeEdge (SenderImpl, ReceiverImpl);
@@ -2030,17 +2023,11 @@ void executable_command_graph::update(
20302023}
20312024
20322025void executable_command_graph::update (const node &Node) {
2033- impl->update (sycl::detail::getSyclObjImpl (Node));
2026+ impl->update (* sycl::detail::getSyclObjImpl (Node));
20342027}
20352028
20362029void executable_command_graph::update (const std::vector<node> &Nodes) {
2037- std::vector<std::shared_ptr<node_impl>> NodeImpls{};
2038- NodeImpls.reserve (Nodes.size ());
2039- for (auto &Node : Nodes) {
2040- NodeImpls.push_back (sycl::detail::getSyclObjImpl (Node));
2041- }
2042-
2043- impl->update (NodeImpls);
2030+ impl->update (Nodes);
20442031}
20452032
20462033size_t executable_command_graph::get_required_mem_size () const {
0 commit comments