@@ -80,41 +80,43 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
8080 forwardSlice->insert (v.rbegin (), v.rend ());
8181}
8282
83- static void getBackwardSliceImpl (Operation *op,
84- SetVector<Operation *> *backwardSlice,
85- const BackwardSliceOptions &options) {
83+ static LogicalResult getBackwardSliceImpl (Operation *op,
84+ SetVector<Operation *> *backwardSlice,
85+ const BackwardSliceOptions &options) {
8686 if (!op || op->hasTrait <OpTrait::IsIsolatedFromAbove>())
87- return ;
87+ return success () ;
8888
8989 // Evaluate whether we should keep this def.
9090 // This is useful in particular to implement scoping; i.e. return the
9191 // transitive backwardSlice in the current scope.
9292 if (options.filter && !options.filter (op))
93- return ;
93+ return success () ;
9494
9595 auto processValue = [&](Value value) {
9696 if (auto *definingOp = value.getDefiningOp ()) {
9797 if (backwardSlice->count (definingOp) == 0 )
98- getBackwardSliceImpl (definingOp, backwardSlice, options);
98+ return getBackwardSliceImpl (definingOp, backwardSlice, options);
9999 } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
100100 if (options.omitBlockArguments )
101- return ;
101+ return success () ;
102102
103103 Block *block = blockArg.getOwner ();
104104 Operation *parentOp = block->getParentOp ();
105105 // TODO: determine whether we want to recurse backward into the other
106106 // blocks of parentOp, which are not technically backward unless they flow
107107 // into us. For now, just bail.
108108 if (parentOp && backwardSlice->count (parentOp) == 0 ) {
109- assert (parentOp->getNumRegions () == 1 &&
110- llvm::hasSingleElement (parentOp->getRegion (0 ).getBlocks ()));
111- getBackwardSliceImpl (parentOp, backwardSlice, options);
109+ if (parentOp->getNumRegions () == 1 &&
110+ llvm::hasSingleElement (parentOp->getRegion (0 ).getBlocks ())) {
111+ return getBackwardSliceImpl (parentOp, backwardSlice, options);
112+ }
112113 }
113- } else {
114- llvm_unreachable (" No definingOp and not a block argument." );
115114 }
115+ return failure ();
116116 };
117117
118+ bool succeeded = true ;
119+
118120 if (!options.omitUsesFromAbove ) {
119121 llvm::for_each (op->getRegions (), [&](Region ®ion) {
120122 // Walk this region recursively to collect the regions that descend from
@@ -125,36 +127,41 @@ static void getBackwardSliceImpl(Operation *op,
125127 region.walk ([&](Operation *op) {
126128 for (OpOperand &operand : op->getOpOperands ()) {
127129 if (!descendents.contains (operand.get ().getParentRegion ()))
128- processValue (operand.get ());
130+ if (!processValue (operand.get ()).succeeded ()) {
131+ return WalkResult::interrupt ();
132+ }
129133 }
134+ return WalkResult::advance ();
130135 });
131136 });
132137 }
133138 llvm::for_each (op->getOperands (), processValue);
134139
135140 backwardSlice->insert (op);
141+ return success (succeeded);
136142}
137143
138- void mlir::getBackwardSlice (Operation *op,
139- SetVector<Operation *> *backwardSlice,
140- const BackwardSliceOptions &options) {
141- getBackwardSliceImpl (op, backwardSlice, options);
144+ LogicalResult mlir::getBackwardSlice (Operation *op,
145+ SetVector<Operation *> *backwardSlice,
146+ const BackwardSliceOptions &options) {
147+ LogicalResult result = getBackwardSliceImpl (op, backwardSlice, options);
142148
143149 if (!options.inclusive ) {
144150 // Don't insert the top level operation, we just queried on it and don't
145151 // want it in the results.
146152 backwardSlice->remove (op);
147153 }
154+ return result;
148155}
149156
150- void mlir::getBackwardSlice (Value root, SetVector<Operation *> *backwardSlice,
151- const BackwardSliceOptions &options) {
157+ LogicalResult mlir::getBackwardSlice (Value root,
158+ SetVector<Operation *> *backwardSlice,
159+ const BackwardSliceOptions &options) {
152160 if (Operation *definingOp = root.getDefiningOp ()) {
153- getBackwardSlice (definingOp, backwardSlice, options);
154- return ;
161+ return getBackwardSlice (definingOp, backwardSlice, options);
155162 }
156163 Operation *bbAargOwner = cast<BlockArgument>(root).getOwner ()->getParentOp ();
157- getBackwardSlice (bbAargOwner, backwardSlice, options);
164+ return getBackwardSlice (bbAargOwner, backwardSlice, options);
158165}
159166
160167SetVector<Operation *>
@@ -170,7 +177,9 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
170177 auto *currentOp = (slice)[currentIndex];
171178 // Compute and insert the backwardSlice starting from currentOp.
172179 backwardSlice.clear ();
173- getBackwardSlice (currentOp, &backwardSlice, backwardSliceOptions);
180+ LogicalResult result =
181+ getBackwardSlice (currentOp, &backwardSlice, backwardSliceOptions);
182+ assert (result.succeeded ());
174183 slice.insert_range (backwardSlice);
175184
176185 // Compute and insert the forwardSlice starting from currentOp.
@@ -193,7 +202,8 @@ static bool dependsOnCarriedVals(Value value,
193202 sliceOptions.filter = [&](Operation *op) {
194203 return !ancestorOp->isAncestor (op);
195204 };
196- getBackwardSlice (value, &slice, sliceOptions);
205+ LogicalResult result = getBackwardSlice (value, &slice, sliceOptions);
206+ assert (result.succeeded ());
197207
198208 // Check that none of the operands of the operations in the backward slice are
199209 // loop iteration arguments, and neither is the value itself.
0 commit comments