@@ -100,17 +100,16 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
100100 Source.pop ();
101101 SortedNodes.push_back (Node);
102102
103- for (auto &SuccWP : Node->MSuccessors ) {
104- auto Succ = SuccWP.lock ();
103+ for (node_impl &Succ : Node->successors ()) {
105104
106- if (PartitionBounded && (Succ-> MPartitionNum != Node->MPartitionNum )) {
105+ if (PartitionBounded && (Succ. MPartitionNum != Node->MPartitionNum )) {
107106 continue ;
108107 }
109108
110- auto &TotalVisitedEdges = Succ-> MTotalVisitedEdges ;
109+ auto &TotalVisitedEdges = Succ. MTotalVisitedEdges ;
111110 ++TotalVisitedEdges;
112- if (TotalVisitedEdges == Succ-> MPredecessors .size ()) {
113- Source.push (Succ);
111+ if (TotalVisitedEdges == Succ. MPredecessors .size ()) {
112+ Source.push (Succ. weak_from_this () );
114113 }
115114 }
116115 }
@@ -127,14 +126,14 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
127126// / a node with a smaller partition number.
128127// / @param Node Node to assign to the partition.
129128// / @param PartitionNum Number to propagate.
130- void propagatePartitionUp (std::shared_ptr< node_impl> Node, int PartitionNum) {
131- if (((Node-> MPartitionNum != -1 ) && (Node-> MPartitionNum <= PartitionNum)) ||
132- (Node-> MCGType == sycl::detail::CGType::CodeplayHostTask)) {
129+ void propagatePartitionUp (node_impl & Node, int PartitionNum) {
130+ if (((Node. MPartitionNum != -1 ) && (Node. MPartitionNum <= PartitionNum)) ||
131+ (Node. MCGType == sycl::detail::CGType::CodeplayHostTask)) {
133132 return ;
134133 }
135- Node-> MPartitionNum = PartitionNum;
136- for (auto &Predecessor : Node-> MPredecessors ) {
137- propagatePartitionUp (Predecessor. lock () , PartitionNum);
134+ Node. MPartitionNum = PartitionNum;
135+ for (node_impl &Predecessor : Node. predecessors () ) {
136+ propagatePartitionUp (Predecessor, PartitionNum);
138137 }
139138}
140139
@@ -146,17 +145,17 @@ void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
146145// / @param HostTaskList List of host tasks that have already been processed and
147146// / are encountered as successors to the node Node.
148147void propagatePartitionDown (
149- const std::shared_ptr< node_impl> &Node, int PartitionNum,
148+ node_impl &Node, int PartitionNum,
150149 std::list<std::shared_ptr<node_impl>> &HostTaskList) {
151- if (Node-> MCGType == sycl::detail::CGType::CodeplayHostTask) {
152- if (Node-> MPartitionNum != -1 ) {
153- HostTaskList.push_front (Node);
150+ if (Node. MCGType == sycl::detail::CGType::CodeplayHostTask) {
151+ if (Node. MPartitionNum != -1 ) {
152+ HostTaskList.push_front (Node. shared_from_this () );
154153 }
155154 return ;
156155 }
157- Node-> MPartitionNum = PartitionNum;
158- for (auto &Successor : Node-> MSuccessors ) {
159- propagatePartitionDown (Successor. lock () , PartitionNum, HostTaskList);
156+ Node. MPartitionNum = PartitionNum;
157+ for (node_impl &Successor : Node. successors () ) {
158+ propagatePartitionDown (Successor, PartitionNum, HostTaskList);
160159 }
161160}
162161
@@ -165,8 +164,8 @@ void propagatePartitionDown(
165164// / @param Node node to test
166165// / @return True if `Node` is a root of its partition
167166bool isPartitionRoot (std::shared_ptr<node_impl> Node) {
168- for (auto &Predecessor : Node->MPredecessors ) {
169- if (Predecessor.lock ()-> MPartitionNum == Node->MPartitionNum ) {
167+ for (node_impl &Predecessor : Node->predecessors () ) {
168+ if (Predecessor.MPartitionNum == Node->MPartitionNum ) {
170169 return false ;
171170 }
172171 }
@@ -221,15 +220,15 @@ void exec_graph_impl::makePartitions() {
221220 auto Node = HostTaskList.front ();
222221 HostTaskList.pop_front ();
223222 CurrentPartition++;
224- for (auto &Predecessor : Node->MPredecessors ) {
225- propagatePartitionUp (Predecessor. lock () , CurrentPartition);
223+ for (node_impl &Predecessor : Node->predecessors () ) {
224+ propagatePartitionUp (Predecessor, CurrentPartition);
226225 }
227226 CurrentPartition++;
228227 Node->MPartitionNum = CurrentPartition;
229228 CurrentPartition++;
230229 auto TmpSize = HostTaskList.size ();
231- for (auto &Successor : Node->MSuccessors ) {
232- propagatePartitionDown (Successor. lock () , CurrentPartition, HostTaskList);
230+ for (node_impl &Successor : Node->successors () ) {
231+ propagatePartitionDown (Successor, CurrentPartition, HostTaskList);
233232 }
234233 if (HostTaskList.size () > TmpSize) {
235234 // At least one HostTask has been re-numbered so group merge opportunities
@@ -290,9 +289,9 @@ void exec_graph_impl::makePartitions() {
290289 for (const auto &Partition : MPartitions) {
291290 for (auto const &Root : Partition->MRoots ) {
292291 auto RootNode = Root.lock ();
293- for (const auto &Dep : RootNode->MPredecessors ) {
294- auto NodeDep = Dep. lock ();
295- auto &Predecessor = MPartitions[MPartitionNodes[NodeDep]];
292+ for (node_impl &NodeDep : RootNode->predecessors () ) {
293+ auto &Predecessor =
294+ MPartitions[MPartitionNodes[NodeDep. shared_from_this () ]];
296295 Partition->MPredecessors .push_back (Predecessor.get ());
297296 Predecessor->MSuccessors .push_back (Partition.get ());
298297 }
@@ -390,8 +389,8 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
390389 bool ShouldAddDep = true ;
391390 // If any of this node's successors have this requirement then we skip
392391 // adding the current node as a dependency.
393- for (auto &Succ : Node->MSuccessors ) {
394- if (Succ.lock ()-> hasRequirementDependency (Req)) {
392+ for (node_impl &Succ : Node->successors () ) {
393+ if (Succ.hasRequirementDependency (Req)) {
395394 ShouldAddDep = false ;
396395 break ;
397396 }
@@ -721,17 +720,17 @@ void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) {
721720// predecessors until we find the real dependency.
722721void exec_graph_impl::findRealDeps (
723722 std::vector<ur_exp_command_buffer_sync_point_t > &Deps,
724- std::shared_ptr<node_impl> CurrentNode, int ReferencePartitionNum) {
725- if (!CurrentNode->requiresEnqueue ()) {
726- for (auto &N : CurrentNode->MPredecessors ) {
727- auto NodeImpl = N.lock ();
723+ node_impl &CurrentNode, int ReferencePartitionNum) {
724+ if (!CurrentNode.requiresEnqueue ()) {
725+ for (node_impl &NodeImpl : CurrentNode.predecessors ()) {
728726 findRealDeps (Deps, NodeImpl, ReferencePartitionNum);
729727 }
730728 } else {
729+ auto CurrentNodePtr = CurrentNode.shared_from_this ();
731730 // Verify that CurrentNode belongs to the same partition
732- if (MPartitionNodes[CurrentNode ] == ReferencePartitionNum) {
731+ if (MPartitionNodes[CurrentNodePtr ] == ReferencePartitionNum) {
733732 // Verify that the sync point has actually been set for this node.
734- auto SyncPoint = MSyncPoints.find (CurrentNode );
733+ auto SyncPoint = MSyncPoints.find (CurrentNodePtr );
735734 assert (SyncPoint != MSyncPoints.end () &&
736735 " No sync point has been set for node dependency." );
737736 // Check if the dependency has already been added.
@@ -749,8 +748,8 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
749748 ur_exp_command_buffer_handle_t CommandBuffer,
750749 std::shared_ptr<node_impl> Node) {
751750 std::vector<ur_exp_command_buffer_sync_point_t > Deps;
752- for (auto &N : Node->MPredecessors ) {
753- findRealDeps (Deps, N. lock () , MPartitionNodes[Node]);
751+ for (node_impl &N : Node->predecessors () ) {
752+ findRealDeps (Deps, N, MPartitionNodes[Node]);
754753 }
755754 ur_exp_command_buffer_sync_point_t NewSyncPoint;
756755 ur_exp_command_buffer_command_handle_t NewCommand = 0 ;
@@ -805,8 +804,8 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
805804 std::shared_ptr<node_impl> Node) {
806805
807806 std::vector<ur_exp_command_buffer_sync_point_t > Deps;
808- for (auto &N : Node->MPredecessors ) {
809- findRealDeps (Deps, N. lock () , MPartitionNodes[Node]);
807+ for (node_impl &N : Node->predecessors () ) {
808+ findRealDeps (Deps, N, MPartitionNodes[Node]);
810809 }
811810
812811 sycl::detail::EventImplPtr Event =
@@ -1275,8 +1274,8 @@ void exec_graph_impl::duplicateNodes() {
12751274 auto NodeCopy = NewNodes[i];
12761275 // Look through all the original node successors, find their copies and
12771276 // register those as successors with the current copied node
1278- for (auto &NextNode : OriginalNode->MSuccessors ) {
1279- auto Successor = NodesMap.at (NextNode.lock ());
1277+ for (node_impl &NextNode : OriginalNode->successors () ) {
1278+ auto Successor = NodesMap.at (NextNode.shared_from_this ());
12801279 NodeCopy->registerSuccessor (Successor);
12811280 }
12821281 }
@@ -1317,8 +1316,8 @@ void exec_graph_impl::duplicateNodes() {
13171316 auto SubgraphNode = SubgraphNodes[i];
13181317 auto NodeCopy = NewSubgraphNodes[i];
13191318
1320- for (auto &NextNode : SubgraphNode->MSuccessors ) {
1321- auto Successor = SubgraphNodesMap.at (NextNode.lock ());
1319+ for (node_impl &NextNode : SubgraphNode->successors () ) {
1320+ auto Successor = SubgraphNodesMap.at (NextNode.shared_from_this ());
13221321 NodeCopy->registerSuccessor (Successor);
13231322 }
13241323 }
@@ -1339,9 +1338,8 @@ void exec_graph_impl::duplicateNodes() {
13391338 // original subgraph node
13401339
13411340 // Predecessors
1342- for (auto &PredNodeWeak : NewNode->MPredecessors ) {
1343- auto PredNode = PredNodeWeak.lock ();
1344- auto &Successors = PredNode->MSuccessors ;
1341+ for (node_impl &PredNode : NewNode->predecessors ()) {
1342+ auto &Successors = PredNode.MSuccessors ;
13451343
13461344 // Remove the subgraph node from this node's successors
13471345 Successors.erase (std::remove_if (Successors.begin (), Successors.end (),
@@ -1353,14 +1351,13 @@ void exec_graph_impl::duplicateNodes() {
13531351 // Add all input nodes from the subgraph as successors for this node
13541352 // instead
13551353 for (auto &Input : Inputs) {
1356- PredNode-> registerSuccessor (Input);
1354+ PredNode. registerSuccessor (Input);
13571355 }
13581356 }
13591357
13601358 // Successors
1361- for (auto &SuccNodeWeak : NewNode->MSuccessors ) {
1362- auto SuccNode = SuccNodeWeak.lock ();
1363- auto &Predecessors = SuccNode->MPredecessors ;
1359+ for (node_impl &SuccNode : NewNode->successors ()) {
1360+ auto &Predecessors = SuccNode.MPredecessors ;
13641361
13651362 // Remove the subgraph node from this node's predecessors
13661363 Predecessors.erase (std::remove_if (Predecessors.begin (),
@@ -1373,7 +1370,7 @@ void exec_graph_impl::duplicateNodes() {
13731370 // Add all Output nodes from the subgraph as predecessors for this node
13741371 // instead
13751372 for (auto &Output : Outputs) {
1376- Output->registerSuccessor (SuccNode);
1373+ Output->registerSuccessor (SuccNode. shared_from_this () );
13771374 }
13781375 }
13791376
0 commit comments