88
99#define __SYCL_GRAPH_IMPL_CPP
1010
11+ #include < stack>
1112#include < detail/graph_impl.hpp>
1213#include < detail/handler_impl.hpp>
1314#include < detail/kernel_arg_mask.hpp>
@@ -31,64 +32,47 @@ namespace experimental {
3132namespace detail {
3233
3334namespace {
34- // / Visits a node on the graph and its successors recursively in a depth-first
35- // / approach.
36- // / @param[in] Node The current node being visited.
37- // / @param[in,out] VisitedNodes A set of unique nodes which have already been
38- // / visited.
39- // / @param[in,out] NodeStack Stack of nodes which are currently being visited on
40- // / the current path through the graph.
41- // / @param[in] NodeFunc The function object to be run on each node. A return
42- // / value of true indicates the search should be ended immediately and the
43- // / function will return.
44- // / @return True if the search should end immediately, false if not.
45- bool visitNodeDepthFirst (
46- std::shared_ptr<node_impl> Node,
47- std::set<std::shared_ptr<node_impl>> &VisitedNodes,
48- std::deque<std::shared_ptr<node_impl>> &NodeStack,
49- std::function<bool (std::shared_ptr<node_impl> &,
50- std::deque<std::shared_ptr<node_impl>> &)>
51- NodeFunc) {
52- auto EarlyReturn = NodeFunc (Node, NodeStack); // Runs before Node is pushed, so NodeStack does not yet contain this visit.
53- if (EarlyReturn) {
54- return true ;
55- }
56- NodeStack.push_back (Node);
57- Node->MVisited = true ;
58- VisitedNodes.emplace (Node);
59- for (auto &Successor : Node->MSuccessors ) {
60- if (visitNodeDepthFirst (Successor.lock (), VisitedNodes, NodeStack, // NOTE(review): lock() result is not null-checked.
61- NodeFunc)) {
62- return true ; // Early exit: NodeStack is left un-popped on this path.
63- }
64- }
65- NodeStack.pop_back ();
66- return false ;
67- }
68-
69- // / Recursively add nodes to execution stack.
70- // / @param NodeImpl Node to schedule.
71- // / @param Schedule Execution ordering to add node to.
72- // / @param PartitionBounded If set to true, the topological sort is stopped at
73- // / partition borders. Hence, nodes belonging to a partition different from the
74- // / NodeImpl partition are not processed.
75- void sortTopological (std::shared_ptr<node_impl> NodeImpl,
76- std::list<std::shared_ptr<node_impl>> &Schedule,
77- bool PartitionBounded = false ) {
78- for (auto &Succ : NodeImpl->MSuccessors ) {
79- auto NextNode = Succ.lock ();
80- if (PartitionBounded &&
81- (NextNode->MPartitionNum != NodeImpl->MPartitionNum )) {
82- continue ;
83- }
84- // Check if we've already scheduled this node (linear search of Schedule).
85- if (std::find (Schedule.begin (), Schedule.end (), NextNode) ==
86- Schedule.end ()) {
87- sortTopological (NextNode, Schedule, PartitionBounded);
35+ // / Topologically sorts the graph in order to schedule nodes for execution.
36+ // / This implementation is based on Kahn's algorithm; nodes whose incoming
37+ // / edges have all been visited are held in a std::stack until processed.
38+ // / For performance reasons, this function uses the MTotalVisitedEdges
39+ // / member variable of the node_impl class. It's the caller's responsibility to
40+ // / make sure that MTotalVisitedEdges is set to 0 for all nodes in the graph
41+ // / before calling this function.
42+ // / @param[in] Roots List of root nodes.
43+ // / @param[out] SortedNodes List to which nodes are appended in sorted order.
44+ // / @param[in] PartitionBounded If set to true, the topological sort is stopped
45+ // / at partition borders. Hence, successors belonging to a partition different
46+ // / from that of the node being processed are skipped.
47+ void sortTopological (std::set<std::weak_ptr<node_impl>,
48+ std::owner_less<std::weak_ptr<node_impl>>> &Roots,
49+ std::list<std::shared_ptr<node_impl>> &SortedNodes,
50+ bool PartitionBounded) {
51+ std::stack<std::weak_ptr<node_impl>> Source;
52+
53+ for (auto &Node : Roots) {
54+ Source.push (Node);
55+ }
56+
57+ while (!Source.empty ()) {
58+ auto Node = Source.top ().lock (); // NOTE(review): dereferenced below without a null check; assumes the node is still alive.
59+ Source.pop ();
60+ SortedNodes.push_back (Node);
61+
62+ for (auto &SuccWP : Node->MSuccessors ) {
63+ auto Succ = SuccWP.lock (); // NOTE(review): also used without a null check.
64+
65+ if (PartitionBounded && (Succ->MPartitionNum != Node->MPartitionNum )) {
66+ continue ;
67+ }
68+
69+ auto &TotalVisitedEdges = Succ->MTotalVisitedEdges ;
70+ ++TotalVisitedEdges;
71+ if (TotalVisitedEdges == Succ->MPredecessors .size ()) { // Every incoming edge has been counted: Succ is ready. NOTE(review): the count is compared against all predecessors even when PartitionBounded is set - confirm this is intended.
72+ Source.push (Succ);
73+ }
8874 }
8975 }
90-
91- Schedule.push_front (NodeImpl);
9276}
9377
9478// / Propagates the partition number `PartitionNum` to predecessors.
@@ -180,9 +164,9 @@ std::vector<node> createNodesFromImpls(
180164
181165void partition::schedule () {
182166 if (MSchedule.empty ()) { // Build the schedule only if none exists yet.
183- for ( auto &Node : MRoots) {
184- sortTopological (Node. lock (), MSchedule, true );
185- }
167+ // There is no need to reset MTotalVisitedEdges before calling
168+ // sortTopological because this function is only called once per partition.
169+ sortTopological (MRoots, MSchedule, true ); // true: PartitionBounded, do not cross partition borders.
186170 }
187171}
188172
@@ -311,6 +295,7 @@ static void checkGraphPropertiesAndThrow(const property_list &Properties) {
311295#define __SYCL_MANUALLY_DEFINED_PROP (NS_QUALIFIER, PROP_NAME )
312296 switch (PropertyKind) {
313297#include < sycl/ext/oneapi/experimental/detail/properties/graph_properties.def>
298+
314299 default :
315300 return false ;
316301 }
@@ -627,44 +612,20 @@ bool graph_impl::clearQueues() {
627612 return AnyQueuesCleared;
628613}
629614
630- void graph_impl::searchDepthFirst (
631- std::function<bool (std::shared_ptr<node_impl> &,
632- std::deque<std::shared_ptr<node_impl>> &)>
633- NodeFunc) {
634- // Track nodes visited during the search which can be used by NodeFunc in
635- // depth first search queries. Currently unused but is an
636- // integral part of depth first searches.
637- std::set<std::shared_ptr<node_impl>> VisitedNodes;
615+ bool graph_impl::checkForCycles () {
616+ std::list<std::shared_ptr<node_impl>> SortedNodes;
617+ sortTopological (MRoots, SortedNodes, false ); // false: do not stop at partition borders.
638618
639- for (auto &Root : MRoots) {
640- std::deque<std::shared_ptr<node_impl>> NodeStack;
641- if (visitNodeDepthFirst (Root.lock (), VisitedNodes, NodeStack, NodeFunc)) {
642- break ;
643- }
644- }
619+ // If after a topological sort, not all the nodes in the graph are sorted,
620+ // then there must be at least one cycle in the graph. This is guaranteed
621+ // by Kahn's algorithm, which sortTopological() implements.
622+ bool CycleFound = SortedNodes.size () != MNodeStorage.size ();
645623
646- // Reset the visited status of all nodes encountered in the search.
647- for (auto &Node : VisitedNodes ) {
648- Node->MVisited = false ;
624+ // Reset the MTotalVisitedEdges variable to prepare for the next cycle check.
625+ for (auto &Node : MNodeStorage ) {
626+ Node->MTotalVisitedEdges = 0 ;
649627 }
650- }
651628
652- bool graph_impl::checkForCycles () {
653- // Using a depth-first search and checking if we visit a node more than once in
654- // the current path to identify if there are cycles.
655- bool CycleFound = false ;
656- auto CheckFunc = [&](std::shared_ptr<node_impl> &Node,
657- std::deque<std::shared_ptr<node_impl>> &NodeStack) {
658- // If the current node has previously been found in the current path through
659- // the graph then we have a cycle and we end the search early.
660- if (std::find (NodeStack.begin (), NodeStack.end (), Node) !=
661- NodeStack.end ()) {
662- CycleFound = true ;
663- return true ;
664- }
665- return false ;
666- };
667- searchDepthFirst (CheckFunc);
668629 return CycleFound;
669630}
670631
@@ -698,19 +659,31 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
698659 " Dest must be a node inside the graph." );
699660 }
700661
662+ bool DestWasGraphRoot = Dest->MPredecessors .size () == 0 ;
663+
701664 // We need to add the edges first before checking for cycles
702665 Src->registerSuccessor (Dest);
703666
667+ bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors .size () == 1 ;
668+ if (DestLostRootStatus) {
669+ // Dest is no longer a Root node, so we need to remove it from MRoots.
670+ MRoots.erase (Dest);
671+ }
672+
704673 // We can skip cycle checks if either Dest has no successors (cycle not
705674 // possible) or cycle checks have been disabled with the no_cycle_check
706675 // property;
707676 if (Dest->MSuccessors .empty () || !MSkipCycleChecks) {
708677 bool CycleFound = checkForCycles ();
709678
710679 if (CycleFound) {
711- // Remove the added successor and predecessor
680+ // Remove the added successor and predecessor.
712681 Src->MSuccessors .pop_back ();
713682 Dest->MPredecessors .pop_back ();
683+ if (DestLostRootStatus) {
684+ // Add Dest back into MRoots.
685+ MRoots.insert (Dest);
686+ }
714687
715688 throw sycl::exception (make_error_code (sycl::errc::invalid),
716689 " Command graphs cannot contain cycles." );
0 commit comments