2626using namespace mlir ;
2727
2828static void
29- getForwardSliceImpl (Operation *op, SetVector<Operation *> *forwardSlice,
29+ getForwardSliceImpl (Operation *op, DenseSet<Operation *> &visited,
30+ SetVector<Operation *> *forwardSlice,
3031 const SliceOptions::TransitiveFilter &filter = nullptr ) {
3132 if (!op)
3233 return ;
@@ -40,20 +41,41 @@ getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
4041 for (Region ®ion : op->getRegions ())
4142 for (Block &block : region)
4243 for (Operation &blockOp : block)
43- if (forwardSlice->count (&blockOp) == 0 )
44- getForwardSliceImpl (&blockOp, forwardSlice, filter);
45- for (Value result : op->getResults ()) {
46- for (Operation *userOp : result.getUsers ())
47- if (forwardSlice->count (userOp) == 0 )
48- getForwardSliceImpl (userOp, forwardSlice, filter);
49- }
44+ if (forwardSlice->count (&blockOp) == 0 ) {
45+ // We don't have to check if the 'blockOp' is already visited because
46+ // there cannot be a traversal path from this nested op to the parent
47+ // and thus a cycle cannot be closed here. We still have to mark it
48+ // as visited to stop before visiting this operation again if it is
49+ // part of a cycle.
50+ visited.insert (&blockOp);
51+ getForwardSliceImpl (&blockOp, visited, forwardSlice, filter);
52+ visited.erase (&blockOp);
53+ }
54+
55+ for (Value result : op->getResults ())
56+ for (Operation *userOp : result.getUsers ()) {
57+ // A cycle can only occur within a basic block (not across regions or
58+ // basic blocks) because the parent region must be a graph region, graph
59+ // regions are restricted to always have 0 or 1 blocks, and there cannot
60+ // be a def-use edge from a nested operation to an operation in an
61+ // ancestor region. Therefore, we don't have to but may use the same
62+ // 'visited' set across regions/blocks as long as we remove operations
63+ // from the set again when the DFS traverses back from the leaf to the
64+ // root.
65+ if (forwardSlice->count (userOp) == 0 && visited.insert (userOp).second )
66+ getForwardSliceImpl (userOp, visited, forwardSlice, filter);
67+
68+ visited.erase (userOp);
69+ }
5070
5171 forwardSlice->insert (op);
5272}
5373
5474void mlir::getForwardSlice (Operation *op, SetVector<Operation *> *forwardSlice,
5575 const ForwardSliceOptions &options) {
56- getForwardSliceImpl (op, forwardSlice, options.filter );
76+ DenseSet<Operation *> visited;
77+ visited.insert (op);
78+ getForwardSliceImpl (op, visited, forwardSlice, options.filter );
5779 if (!options.inclusive ) {
5880 // Don't insert the top level operation, we just queried on it and don't
5981 // want it in the results.
@@ -69,8 +91,12 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
6991
7092void mlir::getForwardSlice (Value root, SetVector<Operation *> *forwardSlice,
7193 const SliceOptions &options) {
72- for (Operation *user : root.getUsers ())
73- getForwardSliceImpl (user, forwardSlice, options.filter );
94+ DenseSet<Operation *> visited;
95+ for (Operation *user : root.getUsers ()) {
96+ visited.insert (user);
97+ getForwardSliceImpl (user, visited, forwardSlice, options.filter );
98+ visited.erase (user);
99+ }
74100
75101 // Reverse to get back the actual topological order.
76102 // std::reverse does not work out of the box on SetVector and I want an
@@ -80,6 +106,7 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
80106}
81107
82108static LogicalResult getBackwardSliceImpl (Operation *op,
109+ DenseSet<Operation *> &visited,
83110 SetVector<Operation *> *backwardSlice,
84111 const BackwardSliceOptions &options) {
85112 if (!op || op->hasTrait <OpTrait::IsIsolatedFromAbove>())
@@ -93,8 +120,12 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
93120
94121 auto processValue = [&](Value value) {
95122 if (auto *definingOp = value.getDefiningOp ()) {
96- if (backwardSlice->count (definingOp) == 0 )
97- return getBackwardSliceImpl (definingOp, backwardSlice, options);
123+ if (backwardSlice->count (definingOp) == 0 &&
124+ visited.insert (definingOp).second )
125+ return getBackwardSliceImpl (definingOp, visited, backwardSlice,
126+ options);
127+
128+ visited.erase (definingOp);
98129 } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
99130 if (options.omitBlockArguments )
100131 return success ();
@@ -107,7 +138,8 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
107138 if (parentOp && backwardSlice->count (parentOp) == 0 ) {
108139 if (parentOp->getNumRegions () == 1 &&
109140 llvm::hasSingleElement (parentOp->getRegion (0 ).getBlocks ())) {
110- return getBackwardSliceImpl (parentOp, backwardSlice, options);
141+ return getBackwardSliceImpl (parentOp, visited, backwardSlice,
142+ options);
111143 }
112144 }
113145 } else {
@@ -145,7 +177,10 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
145177LogicalResult mlir::getBackwardSlice (Operation *op,
146178 SetVector<Operation *> *backwardSlice,
147179 const BackwardSliceOptions &options) {
148- LogicalResult result = getBackwardSliceImpl (op, backwardSlice, options);
180+ DenseSet<Operation *> visited;
181+ visited.insert (op);
182+ LogicalResult result =
183+ getBackwardSliceImpl (op, visited, backwardSlice, options);
149184
150185 if (!options.inclusive ) {
151186 // Don't insert the top level operation, we just queried on it and don't
0 commit comments