@@ -341,7 +341,7 @@ void graph_impl::addRoot(node_impl &Root) { MRoots.insert(&Root); }
341341
// Erases the address of Root from MRoots.
342342void graph_impl::removeRoot (node_impl &Root) { MRoots.erase (&Root); }
343343
344- std::set<std::shared_ptr< node_impl> > graph_impl::getCGEdges (
344+ std::set<node_impl * > graph_impl::getCGEdges (
345345 const std::shared_ptr<sycl::detail::CG> &CommandGroup) const {
346346 const auto &Requirements = CommandGroup->getRequirements ();
347347 if (!MAllowBuffers && Requirements.size ()) {
@@ -362,14 +362,14 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
362362 }
363363
364364 // Add any nodes specified by event dependencies into the dependency list
365- std::set<std::shared_ptr< node_impl> > UniqueDeps;
365+ std::set<node_impl * > UniqueDeps;
366366 for (auto &Dep : CommandGroup->getEvents ()) {
367367 if (auto NodeImpl = MEventsMap.find (Dep); NodeImpl == MEventsMap.end ()) {
368368 throw sycl::exception (sycl::make_error_code (errc::invalid),
369369 " Event dependency from handler::depends_on does "
370370 " not correspond to a node within the graph" );
371371 } else {
372- UniqueDeps.insert (NodeImpl->second -> shared_from_this () );
372+ UniqueDeps.insert (NodeImpl->second );
373373 }
374374 }
375375
@@ -388,7 +388,7 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
388388 }
389389 }
390390 if (ShouldAddDep) {
391- UniqueDeps.insert (Node);
391+ UniqueDeps.insert (Node. get () );
392392 }
393393 }
394394 }
@@ -501,7 +501,7 @@ graph_impl::add(node_type NodeType,
501501 nodes_range Deps) {
502502
503503 // A unique set of dependencies obtained by checking requirements and events
504- std::set<std::shared_ptr< node_impl> > UniqueDeps = getCGEdges (CommandGroup);
504+ std::set<node_impl * > UniqueDeps = getCGEdges (CommandGroup);
505505
506506 // Track and mark the memory objects being used by the graph.
507507 markCGMemObjs (CommandGroup);
@@ -530,8 +530,7 @@ std::shared_ptr<node_impl>
530530graph_impl::add (std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
531531 nodes_range Deps) {
532532 // Set of Dependent nodes based on CG event and accessor dependencies.
533- std::set<std::shared_ptr<node_impl>> DynCGDeps =
534- getCGEdges (DynCGImpl->MCommandGroups [0 ]);
533+ std::set<node_impl *> DynCGDeps = getCGEdges (DynCGImpl->MCommandGroups [0 ]);
535534 for (unsigned i = 1 ; i < DynCGImpl->getNumCGs (); i++) {
536535 auto &CG = DynCGImpl->MCommandGroups [i];
537536 auto CGEdges = getCGEdges (CG);
@@ -1559,7 +1558,7 @@ bool exec_graph_impl::needsScheduledUpdate(
15591558}
15601559
15611560void exec_graph_impl::populateURKernelUpdateStructs (
1562- const std::shared_ptr< node_impl> &Node, FastKernelCacheValPtr &BundleObjs,
1561+ node_impl &Node, FastKernelCacheValPtr &BundleObjs,
15631562 std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t > &MemobjDescs,
15641563 std::vector<ur_kernel_arg_mem_obj_properties_t > &MemobjProps,
15651564 std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t > &PtrDescs,
@@ -1574,7 +1573,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
15741573
15751574 // Gather arg information from Node
15761575 auto &ExecCG =
1577- *(static_cast <sycl::detail::CGExecKernel *>(Node-> MCommandGroup .get ()));
1576+ *(static_cast <sycl::detail::CGExecKernel *>(Node. MCommandGroup .get ()));
15781577 // Copy args because we may modify them
15791578 std::vector<sycl::detail::ArgDesc> NodeArgs = ExecCG.getArguments ();
15801579 // Copy NDR desc since we need to modify it
@@ -1713,7 +1712,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
17131712 // TODO: Handle subgraphs or any other cases where multiple nodes may be
17141713 // associated with a single key, once those node types are supported for
17151714 // update.
1716- auto ExecNode = MIDCache.find (Node-> MID );
1715+ auto ExecNode = MIDCache.find (Node. MID );
17171716 assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
17181717
17191718 auto Command = MCommandMap.find (ExecNode->second .get ());
@@ -1725,30 +1724,29 @@ void exec_graph_impl::populateURKernelUpdateStructs(
17251724 ExecNode->second ->updateFromOtherNode (Node);
17261725}
17271726
1728- std::map<int , std::vector<std::shared_ptr<node_impl>>>
1729- exec_graph_impl::getURUpdatableNodes (
1730- const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1727+ std::map<int , std::vector<node_impl *>>
1728+ exec_graph_impl::getURUpdatableNodes (nodes_range Nodes) const {
17311729 // Iterate over the list of nodes, and for every node that can
17321730 // be updated through UR, add it to the list of nodes
17331731 // that can be updated for the UR command-buffer partition.
1734- std::map<int , std::vector<std::shared_ptr< node_impl> >> PartitionedNodes;
1732+ std::map<int , std::vector<node_impl * >> PartitionedNodes;
17351733
17361734 // Initialize vector for each partition
17371735 for (size_t i = 0 ; i < MPartitions.size (); i++) {
17381736 PartitionedNodes[i] = {};
17391737 }
17401738
1741- for (auto &Node : Nodes) {
1739+ for (node_impl &Node : Nodes) {
17421740 // Kernel node update is the only command type supported in UR for update.
1743- if (Node-> MCGType != sycl::detail::CGType::Kernel) {
1741+ if (Node. MCGType != sycl::detail::CGType::Kernel) {
17441742 continue ;
17451743 }
17461744
1747- auto ExecNode = MIDCache.find (Node-> MID );
1745+ auto ExecNode = MIDCache.find (Node. MID );
17481746 assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
17491747 auto PartitionIndex = MPartitionNodes.find (ExecNode->second .get ());
17501748 assert (PartitionIndex != MPartitionNodes.end ());
1751- PartitionedNodes[PartitionIndex->second ].push_back (Node);
1749+ PartitionedNodes[PartitionIndex->second ].push_back (& Node);
17521750 }
17531751
17541752 return PartitionedNodes;
@@ -1765,13 +1763,12 @@ void exec_graph_impl::updateHostTasksImpl(
17651763 auto ExecNode = MIDCache.find (Node->MID );
17661764 assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
17671765
1768- ExecNode->second ->updateFromOtherNode (Node);
1766+ ExecNode->second ->updateFromOtherNode (* Node);
17691767 }
17701768}
17711769
1772- void exec_graph_impl::updateURImpl (
1773- ur_exp_command_buffer_handle_t CommandBuffer,
1774- const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1770+ void exec_graph_impl::updateURImpl (ur_exp_command_buffer_handle_t CommandBuffer,
1771+ nodes_range Nodes) const {
17751772 const size_t NumUpdatableNodes = Nodes.size ();
17761773 if (NumUpdatableNodes == 0 ) {
17771774 return ;
@@ -1797,10 +1794,10 @@ void exec_graph_impl::updateURImpl(
17971794 std::vector<FastKernelCacheValPtr> KernelBundleObjList (NumUpdatableNodes);
17981795
17991796 size_t StructListIndex = 0 ;
1800- for (auto &Node : Nodes) {
1797+ for (node_impl &Node : Nodes) {
18011798 // This should be the case when getURUpdatableNodes() is used to
18021799 // create the list of nodes.
1803- assert (Node-> MCGType == sycl::detail::CGType::Kernel);
1800+ assert (Node. MCGType == sycl::detail::CGType::Kernel);
18041801
18051802 auto &MemobjDescs = MemobjDescsList[StructListIndex];
18061803 auto &MemobjProps = MemobjPropsList[StructListIndex];
0 commit comments