@@ -255,7 +255,7 @@ void exec_graph_impl::makePartitions() {
255255 const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
256256 for (auto &Node : MNodeStorage) {
257257 if (Node->MPartitionNum == i) {
258- MPartitionNodes[Node] = PartitionFinalNum;
258+ MPartitionNodes[Node. get () ] = PartitionFinalNum;
259259 if (isPartitionRoot (Node)) {
260260 Partition->MRoots .insert (Node);
261261 if (Node->MCGType == CGType::CodeplayHostTask) {
@@ -290,8 +290,7 @@ void exec_graph_impl::makePartitions() {
290290 for (auto const &Root : Partition->MRoots ) {
291291 auto RootNode = Root.lock ();
292292 for (node_impl &NodeDep : RootNode->predecessors ()) {
293- auto &Predecessor =
294- MPartitions[MPartitionNodes[NodeDep.shared_from_this ()]];
293+ auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
295294 Partition->MPredecessors .push_back (Predecessor.get ());
296295 Predecessor->MSuccessors .push_back (Partition.get ());
297296 }
@@ -610,8 +609,7 @@ bool graph_impl::checkForCycles() {
610609 return CycleFound;
611610}
612611
613- std::shared_ptr<node_impl>
614- graph_impl::getLastInorderNode (sycl::detail::queue_impl *Queue) {
612+ node_impl *graph_impl::getLastInorderNode (sycl::detail::queue_impl *Queue) {
615613 if (!Queue) {
616614 assert (0 ==
617615 MInorderQueueMap.count (std::weak_ptr<sycl::detail::queue_impl>{}));
@@ -624,8 +622,8 @@ graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
624622}
625623
626624void graph_impl::setLastInorderNode (sycl::detail::queue_impl &Queue,
627- std::shared_ptr< node_impl> Node) {
628- MInorderQueueMap[Queue.weak_from_this ()] = std::move ( Node) ;
625+ node_impl & Node) {
626+ MInorderQueueMap[Queue.weak_from_this ()] = & Node;
629627}
630628
631629void graph_impl::makeEdge (std::shared_ptr<node_impl> Src,
@@ -728,9 +726,9 @@ void exec_graph_impl::findRealDeps(
728726 } else {
729727 auto CurrentNodePtr = CurrentNode.shared_from_this ();
730728 // Verify if CurrentNode belong the the same partition
731- if (MPartitionNodes[CurrentNodePtr ] == ReferencePartitionNum) {
729+ if (MPartitionNodes[&CurrentNode ] == ReferencePartitionNum) {
732730 // Verify that the sync point has actually been set for this node.
733- auto SyncPoint = MSyncPoints.find (CurrentNodePtr );
731+ auto SyncPoint = MSyncPoints.find (&CurrentNode );
734732 assert (SyncPoint != MSyncPoints.end () &&
735733 " No sync point has been set for node dependency." );
736734 // Check if the dependency has already been added.
@@ -749,7 +747,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
749747 std::shared_ptr<node_impl> Node) {
750748 std::vector<ur_exp_command_buffer_sync_point_t > Deps;
751749 for (node_impl &N : Node->predecessors ()) {
752- findRealDeps (Deps, N, MPartitionNodes[Node]);
750+ findRealDeps (Deps, N, MPartitionNodes[Node. get () ]);
753751 }
754752 ur_exp_command_buffer_sync_point_t NewSyncPoint;
755753 ur_exp_command_buffer_command_handle_t NewCommand = 0 ;
@@ -782,7 +780,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
782780 Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr , nullptr );
783781
784782 if (MIsUpdatable) {
785- MCommandMap[Node] = NewCommand;
783+ MCommandMap[Node. get () ] = NewCommand;
786784 }
787785
788786 if (Res != UR_RESULT_SUCCESS) {
@@ -805,7 +803,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
805803
806804 std::vector<ur_exp_command_buffer_sync_point_t > Deps;
807805 for (node_impl &N : Node->predecessors ()) {
808- findRealDeps (Deps, N, MPartitionNodes[Node]);
806+ findRealDeps (Deps, N, MPartitionNodes[Node. get () ]);
809807 }
810808
811809 sycl::detail::EventImplPtr Event =
@@ -814,7 +812,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
814812 /* EventNeeded=*/ true , CommandBuffer, Deps);
815813
816814 if (MIsUpdatable) {
817- MCommandMap[Node] = Event->getCommandBufferCommand ();
815+ MCommandMap[Node. get () ] = Event->getCommandBufferCommand ();
818816 }
819817
820818 return Event->getSyncPoint ();
@@ -830,7 +828,8 @@ void exec_graph_impl::buildRequirements() {
830828 Node->MCommandGroup ->getRequirements ().begin (),
831829 Node->MCommandGroup ->getRequirements ().end ());
832830
833- std::shared_ptr<partition> &Partition = MPartitions[MPartitionNodes[Node]];
831+ std::shared_ptr<partition> &Partition =
832+ MPartitions[MPartitionNodes[Node.get ()]];
834833
835834 Partition->MRequirements .insert (
836835 Partition->MRequirements .end (),
@@ -877,10 +876,10 @@ void exec_graph_impl::createCommandBuffers(
877876 Node->MCommandGroup .get ())
878877 ->MStreams .size () ==
879878 0 ) {
880- MSyncPoints[Node] =
879+ MSyncPoints[Node. get () ] =
881880 enqueueNodeDirect (MContext, DeviceImpl, OutCommandBuffer, Node);
882881 } else {
883- MSyncPoints[Node] = enqueueNode (OutCommandBuffer, Node);
882+ MSyncPoints[Node. get () ] = enqueueNode (OutCommandBuffer, Node);
884883 }
885884 }
886885
@@ -1726,7 +1725,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
17261725 auto ExecNode = MIDCache.find (Node->MID );
17271726 assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
17281727
1729- auto Command = MCommandMap.find (ExecNode->second );
1728+ auto Command = MCommandMap.find (ExecNode->second . get () );
17301729 assert (Command != MCommandMap.end ());
17311730 UpdateDesc.hCommand = Command->second ;
17321731
@@ -1756,7 +1755,7 @@ exec_graph_impl::getURUpdatableNodes(
17561755
17571756 auto ExecNode = MIDCache.find (Node->MID );
17581757 assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1759- auto PartitionIndex = MPartitionNodes.find (ExecNode->second );
1758+ auto PartitionIndex = MPartitionNodes.find (ExecNode->second . get () );
17601759 assert (PartitionIndex != MPartitionNodes.end ());
17611760 PartitionedNodes[PartitionIndex->second ].push_back (Node);
17621761 }
0 commit comments