diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 9f878f0e8ea66..19a60c47dab34 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -327,7 +327,7 @@ graph_impl::graph_impl(const sycl::context &SyclContext, graph_impl::~graph_impl() { try { - clearQueues(); + clearQueues(false /*Needs lock*/); for (auto &MemObj : MMemObjs) { MemObj->markNoLongerBeingUsedInGraph(); } @@ -564,17 +564,21 @@ void graph_impl::removeQueue(sycl::detail::queue_impl &RecordingQueue) { MRecordingQueues.erase(RecordingQueue.weak_from_this()); } -bool graph_impl::clearQueues() { - bool AnyQueuesCleared = false; - for (auto &Queue : MRecordingQueues) { +void graph_impl::clearQueues(bool NeedsLock) { + graph_impl::RecQueuesStorage SwappedQueues; + { + graph_impl::WriteLock Guard(MMutex, std::defer_lock); + if (NeedsLock) { + Guard.lock(); + } + std::swap(MRecordingQueues, SwappedQueues); + } + + for (auto &Queue : SwappedQueues) { if (auto ValidQueue = Queue.lock(); ValidQueue) { ValidQueue->setCommandGraph(nullptr); - AnyQueuesCleared = true; } } - MRecordingQueues.clear(); - - return AnyQueuesCleared; } bool graph_impl::checkForCycles() { @@ -1964,8 +1968,7 @@ void modifiable_command_graph::begin_recording( } void modifiable_command_graph::end_recording() { - graph_impl::WriteLock Lock(impl->MMutex); - impl->clearQueues(); + impl->clearQueues(true /*Needs lock*/); } void modifiable_command_graph::end_recording(queue &RecordingQueue) { diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index 10cbbfab0282c..d35b271493ed0 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -190,9 +190,7 @@ class graph_impl : public std::enable_shared_from_this { /// Remove all queues which are recording to this graph, also sets all queues /// cleared back to the executing state. - /// - /// @return True if any queues were removed. - bool clearQueues(); + void clearQueues(bool NeedsLock); /// Associate a sycl event with a node in the graph. /// @param EventImpl Event to associate with a node in map. @@ -561,10 +559,12 @@ class graph_impl : public std::enable_shared_from_this { /// Device associated with this graph. All graph nodes will execute on this /// device. sycl::device MDevice; + + using RecQueuesStorage = + std::set, + std::owner_less>>; /// Unique set of queues which are currently recording to this graph. - std::set, - std::owner_less>> - MRecordingQueues; + RecQueuesStorage MRecordingQueues; /// Map of events to their associated recorded nodes. std::unordered_map, node_impl *> MEventsMap; diff --git a/sycl/source/detail/queue_impl.hpp b/sycl/source/detail/queue_impl.hpp index 7c793b619ecab..b7c74e07c2681 100644 --- a/sycl/source/detail/queue_impl.hpp +++ b/sycl/source/detail/queue_impl.hpp @@ -602,7 +602,8 @@ class queue_impl : public std::enable_shared_from_this { bool CallerNeedsEvent); void setCommandGraphUnlocked( - std::shared_ptr Graph) { + const std::shared_ptr + &Graph) { MGraph = Graph; MExtGraphDeps.reset(); @@ -614,7 +615,8 @@ class queue_impl : public std::enable_shared_from_this { } void setCommandGraph( - std::shared_ptr Graph) { + const std::shared_ptr + &Graph) { std::lock_guard Lock(MMutex); setCommandGraphUnlocked(Graph); }