Skip to content

Commit f04c79b

Browse files
author
Ewan Crawford
authored
[SYCL][Graph] enable_shared_from_this refactor (#15195)
Use `std::enable_shared_from_this` to remove need for passing a shared pointer of `this` as a function parameter. `std::enable_shared_from_this` usage was previously introduced to graph code in #14453 (comment)
1 parent 57cf62c commit f04c79b

File tree

3 files changed

+32
-52
lines changed

3 files changed

+32
-52
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,6 @@ graph_impl::~graph_impl() {
308308
}
309309

310310
std::shared_ptr<node_impl> graph_impl::addNodesToExits(
311-
const std::shared_ptr<graph_impl> &Impl,
312311
const std::list<std::shared_ptr<node_impl>> &NodeList) {
313312
// Find all input and output nodes from the node list
314313
std::vector<std::shared_ptr<node_impl>> Inputs;
@@ -327,18 +326,18 @@ std::shared_ptr<node_impl> graph_impl::addNodesToExits(
327326
for (auto &NodeImpl : MNodeStorage) {
328327
if (NodeImpl->MSuccessors.size() == 0) {
329328
for (auto &Input : Inputs) {
330-
NodeImpl->registerSuccessor(Input, NodeImpl);
329+
NodeImpl->registerSuccessor(Input);
331330
}
332331
}
333332
}
334333

335334
// Add all the new nodes to the node storage
336335
for (auto &Node : NodeList) {
337336
MNodeStorage.push_back(Node);
338-
addEventForNode(Impl, std::make_shared<sycl::detail::event_impl>(), Node);
337+
addEventForNode(std::make_shared<sycl::detail::event_impl>(), Node);
339338
}
340339

341-
return this->add(Impl, Outputs);
340+
return this->add(Outputs);
342341
}
343342

344343
void graph_impl::addRoot(const std::shared_ptr<node_impl> &Root) {
@@ -350,8 +349,7 @@ void graph_impl::removeRoot(const std::shared_ptr<node_impl> &Root) {
350349
}
351350

352351
std::shared_ptr<node_impl>
353-
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
354-
const std::vector<std::shared_ptr<node_impl>> &Dep) {
352+
graph_impl::add(const std::vector<std::shared_ptr<node_impl>> &Dep) {
355353
// Copy deps so we can modify them
356354
auto Deps = Dep;
357355

@@ -361,17 +359,16 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
361359

362360
addDepsToNode(NodeImpl, Deps);
363361
// Add an event associated with this explicit node for mixed usage
364-
addEventForNode(Impl, std::make_shared<sycl::detail::event_impl>(), NodeImpl);
362+
addEventForNode(std::make_shared<sycl::detail::event_impl>(), NodeImpl);
365363
return NodeImpl;
366364
}
367365

368366
std::shared_ptr<node_impl>
369-
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
370-
std::function<void(handler &)> CGF,
367+
graph_impl::add(std::function<void(handler &)> CGF,
371368
const std::vector<sycl::detail::ArgDesc> &Args,
372369
const std::vector<std::shared_ptr<node_impl>> &Dep) {
373370
(void)Args;
374-
sycl::handler Handler{Impl};
371+
sycl::handler Handler{shared_from_this()};
375372
CGF(Handler);
376373

377374
if (Handler.getType() == sycl::detail::CGType::Barrier) {
@@ -394,7 +391,7 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
394391
this->add(NodeType, std::move(Handler.impl->MGraphNodeCG), Dep);
395392
NodeImpl->MNDRangeUsed = Handler.impl->MNDRangeUsed;
396393
// Add an event associated with this explicit node for mixed usage
397-
addEventForNode(Impl, std::make_shared<sycl::detail::event_impl>(), NodeImpl);
394+
addEventForNode(std::make_shared<sycl::detail::event_impl>(), NodeImpl);
398395

399396
// Retrieve any dynamic parameters which have been registered in the CGF and
400397
// register the actual nodes with them.
@@ -414,8 +411,7 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
414411
}
415412

416413
std::shared_ptr<node_impl>
417-
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
418-
const std::vector<sycl::detail::EventImplPtr> Events) {
414+
graph_impl::add(const std::vector<sycl::detail::EventImplPtr> Events) {
419415

420416
std::vector<std::shared_ptr<node_impl>> Deps;
421417

@@ -430,7 +426,7 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
430426
}
431427
}
432428

433-
return this->add(Impl, Deps);
429+
return this->add(Deps);
434430
}
435431

436432
std::shared_ptr<node_impl>
@@ -594,7 +590,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
594590
}
595591

596592
// We need to add the edges first before checking for cycles
597-
Src->registerSuccessor(Dest, Src);
593+
Src->registerSuccessor(Dest);
598594

599595
// We can skip cycle checks if either Dest has no successors (cycle not
600596
// possible) or cycle checks have been disabled with the no_cycle_check
@@ -1061,7 +1057,7 @@ void exec_graph_impl::duplicateNodes() {
10611057
// register those as successors with the current copied node
10621058
for (auto &NextNode : OriginalNode->MSuccessors) {
10631059
auto Successor = NodesMap.at(NextNode.lock());
1064-
NodeCopy->registerSuccessor(Successor, NodeCopy);
1060+
NodeCopy->registerSuccessor(Successor);
10651061
}
10661062
}
10671063

@@ -1103,7 +1099,7 @@ void exec_graph_impl::duplicateNodes() {
11031099

11041100
for (auto &NextNode : SubgraphNode->MSuccessors) {
11051101
auto Successor = SubgraphNodesMap.at(NextNode.lock());
1106-
NodeCopy->registerSuccessor(Successor, NodeCopy);
1102+
NodeCopy->registerSuccessor(Successor);
11071103
}
11081104
}
11091105

@@ -1137,7 +1133,7 @@ void exec_graph_impl::duplicateNodes() {
11371133
// Add all input nodes from the subgraph as successors for this node
11381134
// instead
11391135
for (auto &Input : Inputs) {
1140-
PredNode->registerSuccessor(Input, PredNode);
1136+
PredNode->registerSuccessor(Input);
11411137
}
11421138
}
11431139

@@ -1157,7 +1153,7 @@ void exec_graph_impl::duplicateNodes() {
11571153
// Add all Output nodes from the subgraph as predecessors for this node
11581154
// instead
11591155
for (auto &Output : Outputs) {
1160-
Output->registerSuccessor(SuccNode, Output);
1156+
Output->registerSuccessor(SuccNode);
11611157
}
11621158
}
11631159

@@ -1531,7 +1527,7 @@ node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
15311527
}
15321528

15331529
graph_impl::WriteLock Lock(impl->MMutex);
1534-
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(impl, DepImpls);
1530+
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
15351531
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
15361532
}
15371533

@@ -1544,8 +1540,7 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
15441540
}
15451541

15461542
graph_impl::WriteLock Lock(impl->MMutex);
1547-
std::shared_ptr<detail::node_impl> NodeImpl =
1548-
impl->add(impl, CGF, {}, DepImpls);
1543+
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(CGF, {}, DepImpls);
15491544
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
15501545
}
15511546

sycl/source/detail/graph_impl.hpp

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) {
7777
}
7878

7979
/// Implementation of node class from SYCL_EXT_ONEAPI_GRAPH.
80-
class node_impl {
80+
class node_impl : public std::enable_shared_from_this<node_impl> {
8181
public:
8282
using id_type = uint64_t;
8383

@@ -112,20 +112,15 @@ class node_impl {
112112

113113
/// Add successor to the node.
114114
/// @param Node Node to add as a successor.
115-
/// @param Prev Predecessor to \p node being added as successor.
116-
///
117-
/// \p Prev should be a shared_ptr to an instance of this object, but can't
118-
/// use a raw \p this pointer, so the extra \p Prev parameter is passed.
119-
void registerSuccessor(const std::shared_ptr<node_impl> &Node,
120-
const std::shared_ptr<node_impl> &Prev) {
115+
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
121116
if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
122117
[Node](const std::weak_ptr<node_impl> &Ptr) {
123118
return Ptr.lock() == Node;
124119
}) != MSuccessors.end()) {
125120
return;
126121
}
127122
MSuccessors.push_back(Node);
128-
Node->registerPredecessor(Prev);
123+
Node->registerPredecessor(shared_from_this());
129124
}
130125

131126
/// Add predecessor to the node.
@@ -161,9 +156,10 @@ class node_impl {
161156
/// Construct a node from another node. This will perform a deep-copy of the
162157
/// command group object associated with this node.
163158
node_impl(node_impl &Other)
164-
: MSuccessors(Other.MSuccessors), MPredecessors(Other.MPredecessors),
165-
MCGType(Other.MCGType), MNodeType(Other.MNodeType),
166-
MCommandGroup(Other.getCGCopy()), MSubGraphImpl(Other.MSubGraphImpl) {}
159+
: enable_shared_from_this(Other), MSuccessors(Other.MSuccessors),
160+
MPredecessors(Other.MPredecessors), MCGType(Other.MCGType),
161+
MNodeType(Other.MNodeType), MCommandGroup(Other.getCGCopy()),
162+
MSubGraphImpl(Other.MSubGraphImpl) {}
167163

168164
/// Copy-assignment operator. This will perform a deep-copy of the
169165
/// command group object associated with this node.
@@ -901,32 +897,26 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
901897
const std::vector<std::shared_ptr<node_impl>> &Dep = {});
902898

903899
/// Create a CGF node in the graph.
904-
/// @param Impl Graph implementation pointer to create a handler with.
905900
/// @param CGF Command-group function to create node with.
906901
/// @param Args Node arguments.
907902
/// @param Dep Dependencies of the created node.
908903
/// @return Created node in the graph.
909904
std::shared_ptr<node_impl>
910-
add(const std::shared_ptr<graph_impl> &Impl,
911-
std::function<void(handler &)> CGF,
905+
add(std::function<void(handler &)> CGF,
912906
const std::vector<sycl::detail::ArgDesc> &Args,
913907
const std::vector<std::shared_ptr<node_impl>> &Dep = {});
914908

915909
/// Create an empty node in the graph.
916-
/// @param Impl Graph implementation pointer.
917910
/// @param Dep List of predecessor nodes.
918911
/// @return Created node in the graph.
919912
std::shared_ptr<node_impl>
920-
add(const std::shared_ptr<graph_impl> &Impl,
921-
const std::vector<std::shared_ptr<node_impl>> &Dep = {});
913+
add(const std::vector<std::shared_ptr<node_impl>> &Dep = {});
922914

923915
/// Create an empty node in the graph.
924-
/// @param Impl Graph implementation pointer.
925916
/// @param Events List of events associated to this node.
926917
/// @return Created node in the graph.
927918
std::shared_ptr<node_impl>
928-
add(const std::shared_ptr<graph_impl> &Impl,
929-
const std::vector<sycl::detail::EventImplPtr> Events);
919+
add(const std::vector<sycl::detail::EventImplPtr> Events);
930920

931921
/// Add a queue to the set of queues which are currently recording to this
932922
/// graph.
@@ -951,15 +941,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
951941
bool clearQueues();
952942

953943
/// Associate a sycl event with a node in the graph.
954-
/// @param GraphImpl shared_ptr to Graph impl associated with this event, aka
955-
/// this.
956944
/// @param EventImpl Event to associate with a node in map.
957945
/// @param NodeImpl Node to associate with event in map.
958-
void addEventForNode(std::shared_ptr<graph_impl> GraphImpl,
959-
std::shared_ptr<sycl::detail::event_impl> EventImpl,
946+
void addEventForNode(std::shared_ptr<sycl::detail::event_impl> EventImpl,
960947
std::shared_ptr<node_impl> NodeImpl) {
961948
if (!(EventImpl->getCommandGraph()))
962-
EventImpl->setCommandGraph(GraphImpl);
949+
EventImpl->setCommandGraph(shared_from_this());
963950
MEventsMap[EventImpl] = NodeImpl;
964951
}
965952

@@ -1238,12 +1225,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
12381225
void addRoot(const std::shared_ptr<node_impl> &Root);
12391226

12401227
/// Adds nodes to the exit nodes of this graph.
1241-
/// @param Impl Graph implementation pointer.
12421228
/// @param NodeList List of nodes from sub-graph in schedule order.
12431229
/// @return An empty node is used to schedule dependencies on this sub-graph.
12441230
std::shared_ptr<node_impl>
1245-
addNodesToExits(const std::shared_ptr<graph_impl> &Impl,
1246-
const std::list<std::shared_ptr<node_impl>> &NodeList);
1231+
addNodesToExits(const std::list<std::shared_ptr<node_impl>> &NodeList);
12471232

12481233
/// Adds dependencies for a new node, if it has no deps it will be
12491234
/// added as a root node.
@@ -1253,7 +1238,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
12531238
const std::vector<std::shared_ptr<node_impl>> &Deps) {
12541239
if (!Deps.empty()) {
12551240
for (auto &N : Deps) {
1256-
N->registerSuccessor(Node, N);
1241+
N->registerSuccessor(Node);
12571242
this->removeRoot(Node);
12581243
}
12591244
} else {

sycl/source/handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ event handler::finalize() {
573573
}
574574

575575
// Associate an event with this new node and return the event.
576-
GraphImpl->addEventForNode(GraphImpl, EventImpl, NodeImpl);
576+
GraphImpl->addEventForNode(EventImpl, NodeImpl);
577577

578578
NodeImpl->MNDRangeUsed = impl->MNDRangeUsed;
579579

0 commit comments

Comments
 (0)