diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index 18349d071bb2e..61ee9ffbd25ff 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -142,6 +142,9 @@ void getForwardSlice(Value root, SetVector *forwardSlice, /// /// This function returns whether the backwards slice was able to be /// successfully computed, and failure if it was unable to determine the slice. +/// This function will presently return failure if a value to process is a +/// blockargument whose parent op has more than one region, or a region with +/// more than one block. LogicalResult getBackwardSlice(Operation *op, SetVector *backwardSlice, const BackwardSliceOptions &options = {}); diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 12dff19ed31d3..c831a534462f9 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -136,11 +136,14 @@ static LogicalResult getBackwardSliceImpl(Operation *op, // blocks of parentOp, which are not technically backward unless they flow // into us. For now, just bail. if (parentOp && backwardSlice->count(parentOp) == 0) { - if (!parentOp->hasTrait() && - parentOp->getNumRegions() == 1 && - parentOp->getRegion(0).hasOneBlock()) { + if (parentOp->hasTrait()) { + return success(); + } else if (parentOp->getNumRegions() == 1 && + parentOp->getRegion(0).hasOneBlock()) { return getBackwardSliceImpl(parentOp, visited, backwardSlice, options); + } else { + return failure(); } } } else { @@ -159,18 +162,25 @@ static LogicalResult getBackwardSliceImpl(Operation *op, SmallPtrSet descendents; region.walk( [&](Region *childRegion) { descendents.insert(childRegion); }); - region.walk([&](Operation *op) { - for (OpOperand &operand : op->getOpOperands()) { - if (!descendents.contains(operand.get().getParentRegion())) - if (!processValue(operand.get()).succeeded()) { - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); + if (region + .walk([&](Operation *op) { + for (OpOperand &operand : op->getOpOperands()) { + if (!descendents.contains(operand.get().getParentRegion())) + if (!processValue(operand.get()).succeeded()) { + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }) + .wasInterrupted()) + succeeded = false; }); } - llvm::for_each(op->getOperands(), processValue); + llvm::for_each(op->getOperands(), [&](Value value) { + if (!processValue(value).succeeded()) { + succeeded = false; + } + }); backwardSlice->insert(op); return success(succeeded);