66//
77// ===----------------------------------------------------------------------===//
88
9+ #include " sycl/accessor.hpp"
910#include < detail/context_impl.hpp>
1011#include < detail/event_impl.hpp>
12+ #include < detail/graph_impl.hpp>
1113#include < detail/queue_impl.hpp>
1214#include < sycl/detail/ur.hpp>
1315#include < sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
@@ -29,6 +31,27 @@ getUrEvents(const std::vector<std::shared_ptr<detail::event_impl>> &DepEvents) {
2931 }
3032 return RetUrEvents;
3133}
34+
35+ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes (
36+ sycl::handler &Handler, const std::shared_ptr<detail::queue_impl> &Queue,
37+ const std::shared_ptr<detail::graph_impl> &Graph,
38+ const std::vector<std::shared_ptr<detail::event_impl>> &DepEvents) {
39+ auto HandlerImpl = detail::getSyclObjImpl (Handler);
40+ // Get dependent graph nodes from any events
41+ auto DepNodes = Graph->getNodesForEvents (DepEvents);
42+ // If this node was added explicitly we may have node deps in the handler as
43+ // well, so add them to the list
44+ DepNodes.insert (DepNodes.end (), HandlerImpl->MNodeDeps .begin (),
45+ HandlerImpl->MNodeDeps .end ());
46+ // If this is being recorded from an in-order queue we need to get the last
47+ // in-order node if any, since this will later become a dependency of the
48+ // node being processed here.
49+ if (const auto &LastInOrderNode = Graph->getLastInorderNode (Queue);
50+ LastInOrderNode) {
51+ DepNodes.push_back (LastInOrderNode);
52+ }
53+ return DepNodes;
54+ }
3255} // namespace
3356
3457__SYCL_EXPORT
@@ -46,22 +69,23 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
4669
4770 auto &Adapter = h.getContextImplPtr ()->getAdapter ();
4871
49- // Get events to wait on .
50- auto depEvents = getUrEvents ( h.impl ->CGData .MEvents ) ;
51- uint32_t numEvents = h. impl -> CGData . MEvents . size ( );
72+ // Get CG event dependencies for this allocation .
73+ const auto &DepEvents = h.impl ->CGData .MEvents ;
74+ auto UREvents = getUrEvents (DepEvents );
5275
5376 void *alloc = nullptr ;
5477
5578 ur_event_handle_t Event = nullptr ;
5679 // If a graph is present do the allocation from the graph memory pool instead.
5780 if (auto Graph = h.getCommandGraph (); Graph) {
58- alloc = Graph->getMemPool ().malloc (size, kind);
81+ auto DepNodes = getDepGraphNodes (h, h.MQueue , Graph, DepEvents);
82+ alloc = Graph->getMemPool ().malloc (size, kind, DepNodes);
5983 } else {
6084 auto &Q = h.MQueue ->getHandleRef ();
6185 Adapter->call <sycl::errc::runtime,
6286 sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
63- Q, (ur_usm_pool_handle_t )0 , size, nullptr , numEvents, depEvents. data (),
64- &alloc, &Event);
87+ Q, (ur_usm_pool_handle_t )0 , size, nullptr , UREvents. size (),
88+ UREvents. data (), &alloc, &Event);
6589 }
6690
6791 // Async malloc must return a void* immediately.
@@ -95,24 +119,26 @@ __SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
95119 auto &Adapter = h.getContextImplPtr ()->getAdapter ();
96120 auto &memPoolImpl = sycl::detail::getSyclObjImpl (pool);
97121
98- // Get events to wait on .
99- auto depEvents = getUrEvents ( h.impl ->CGData .MEvents ) ;
100- uint32_t numEvents = h. impl -> CGData . MEvents . size ( );
122+ // Get CG event dependencies for this allocation .
123+ const auto &DepEvents = h.impl ->CGData .MEvents ;
124+ auto UREvents = getUrEvents (DepEvents );
101125
102126 void *alloc = nullptr ;
103127
104128 ur_event_handle_t Event = nullptr ;
105129 // If a graph is present do the allocation from the graph memory pool instead.
106130 if (auto Graph = h.getCommandGraph (); Graph) {
131+ auto DepNodes = getDepGraphNodes (h, h.MQueue , Graph, DepEvents);
132+
107133 // Memory pool is passed as the graph may use some properties of it.
108- alloc = Graph->getMemPool ().malloc (size, pool.get_alloc_kind (),
134+ alloc = Graph->getMemPool ().malloc (size, pool.get_alloc_kind (), DepNodes,
109135 sycl::detail::getSyclObjImpl (pool));
110136 } else {
111137 auto &Q = h.MQueue ->getHandleRef ();
112138 Adapter->call <sycl::errc::runtime,
113139 sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
114- Q, memPoolImpl.get ()->get_handle (), size, nullptr , numEvents ,
115- depEvents .data (), &alloc, &Event);
140+ Q, memPoolImpl.get ()->get_handle (), size, nullptr , UREvents. size () ,
141+ UREvents .data (), &alloc, &Event);
116142 }
117143 // Async malloc must return a void* immediately.
118144 // Set up CommandGroup which is a no-op and pass the event from the alloc.
@@ -140,6 +166,9 @@ async_malloc_from_pool(const sycl::queue &q, size_t size,
140166}
141167
142168__SYCL_EXPORT void async_free (sycl::handler &h, void *ptr) {
169+ // We only check for errors for the graph here because marking the allocation
170+ // as free in the graph memory pool requires a node object which doesn't exist
171+ // at this point.
143172 if (auto Graph = h.getCommandGraph (); Graph) {
144173 // Check if the pointer to be freed has an associated allocation node, and
145174 // error if not
0 commit comments