diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index 3b731e8bb1c22..d082d2d9f758b 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -138,13 +138,17 @@ void getForwardSlice(Value root, SetVector *forwardSlice, /// Assuming all local orders match the numbering order: /// {1, 2, 5, 3, 4, 6} /// -void getBackwardSlice(Operation *op, SetVector *backwardSlice, - const BackwardSliceOptions &options = {}); +/// This function returns whether the backwards slice was able to be +/// successfully computed, and failure if it was unable to determine the slice. +LogicalResult getBackwardSlice(Operation *op, + SetVector *backwardSlice, + const BackwardSliceOptions &options = {}); /// Value-rooted version of `getBackwardSlice`. Return the union of all backward /// slices for the op defining or owning the value `root`. -void getBackwardSlice(Value root, SetVector *backwardSlice, - const BackwardSliceOptions &options = {}); +LogicalResult getBackwardSlice(Value root, + SetVector *backwardSlice, + const BackwardSliceOptions &options = {}); /// Iteratively computes backward slices and forward slices until /// a fixed point is reached. Returns an `SetVector` which diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h index 1b0e4c32dbe94..40a39d23ca695 100644 --- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h +++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h @@ -112,7 +112,8 @@ bool BackwardSliceMatcher::matches( } return true; }; - getBackwardSlice(rootOp, &backwardSlice, options); + LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options); + assert(result.succeeded() && "expected backward slice to succeed"); return options.inclusive ? backwardSlice.size() > 1 : backwardSlice.size() >= 1; } diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 5aebb19e3a86e..12b9d3adb49fa 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -80,25 +80,25 @@ void mlir::getForwardSlice(Value root, SetVector *forwardSlice, forwardSlice->insert(v.rbegin(), v.rend()); } -static void getBackwardSliceImpl(Operation *op, - SetVector *backwardSlice, - const BackwardSliceOptions &options) { +static LogicalResult getBackwardSliceImpl(Operation *op, + SetVector *backwardSlice, + const BackwardSliceOptions &options) { if (!op || op->hasTrait()) - return; + return success(); // Evaluate whether we should keep this def. // This is useful in particular to implement scoping; i.e. return the // transitive backwardSlice in the current scope. if (options.filter && !options.filter(op)) - return; + return success(); auto processValue = [&](Value value) { if (auto *definingOp = value.getDefiningOp()) { if (backwardSlice->count(definingOp) == 0) - getBackwardSliceImpl(definingOp, backwardSlice, options); + return getBackwardSliceImpl(definingOp, backwardSlice, options); } else if (auto blockArg = dyn_cast(value)) { if (options.omitBlockArguments) - return; + return success(); Block *block = blockArg.getOwner(); Operation *parentOp = block->getParentOp(); @@ -106,15 +106,17 @@ static void getBackwardSliceImpl(Operation *op, // blocks of parentOp, which are not technically backward unless they flow // into us. For now, just bail. if (parentOp && backwardSlice->count(parentOp) == 0) { - assert(parentOp->getNumRegions() == 1 && - llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())); - getBackwardSliceImpl(parentOp, backwardSlice, options); + if (parentOp->getNumRegions() == 1 && + llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) { + return getBackwardSliceImpl(parentOp, backwardSlice, options); + } } - } else { - llvm_unreachable("No definingOp and not a block argument."); } + return failure(); }; + bool succeeded = true; + if (!options.omitUsesFromAbove) { llvm::for_each(op->getRegions(), [&](Region ®ion) { // Walk this region recursively to collect the regions that descend from @@ -125,36 +127,41 @@ static void getBackwardSliceImpl(Operation *op, region.walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { if (!descendents.contains(operand.get().getParentRegion())) - processValue(operand.get()); + if (!processValue(operand.get()).succeeded()) { + return WalkResult::interrupt(); + } } + return WalkResult::advance(); }); }); } llvm::for_each(op->getOperands(), processValue); backwardSlice->insert(op); + return success(succeeded); } -void mlir::getBackwardSlice(Operation *op, - SetVector *backwardSlice, - const BackwardSliceOptions &options) { - getBackwardSliceImpl(op, backwardSlice, options); +LogicalResult mlir::getBackwardSlice(Operation *op, + SetVector *backwardSlice, + const BackwardSliceOptions &options) { + LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options); if (!options.inclusive) { // Don't insert the top level operation, we just queried on it and don't // want it in the results. backwardSlice->remove(op); } + return result; } -void mlir::getBackwardSlice(Value root, SetVector *backwardSlice, - const BackwardSliceOptions &options) { +LogicalResult mlir::getBackwardSlice(Value root, + SetVector *backwardSlice, + const BackwardSliceOptions &options) { if (Operation *definingOp = root.getDefiningOp()) { - getBackwardSlice(definingOp, backwardSlice, options); - return; + return getBackwardSlice(definingOp, backwardSlice, options); } Operation *bbAargOwner = cast(root).getOwner()->getParentOp(); - getBackwardSlice(bbAargOwner, backwardSlice, options); + return getBackwardSlice(bbAargOwner, backwardSlice, options); } SetVector @@ -170,7 +177,9 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions, auto *currentOp = (slice)[currentIndex]; // Compute and insert the backwardSlice starting from currentOp. backwardSlice.clear(); - getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); + LogicalResult result = + getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); + assert(result.succeeded()); slice.insert_range(backwardSlice); // Compute and insert the forwardSlice starting from currentOp. @@ -193,7 +202,8 @@ static bool dependsOnCarriedVals(Value value, sliceOptions.filter = [&](Operation *op) { return !ancestorOp->isAncestor(op); }; - getBackwardSlice(value, &slice, sliceOptions); + LogicalResult result = getBackwardSlice(value, &slice, sliceOptions); + assert(result.succeeded()); // Check that none of the operands of the operations in the backward slice are // loop iteration arguments, and neither is the value itself. diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 8b16da387457d..0ec9ddc25ff8d 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -317,7 +317,9 @@ getSliceContract(Operation *op, auto *currentOp = (slice)[currentIndex]; // Compute and insert the backwardSlice starting from currentOp. backwardSlice.clear(); - getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); + LogicalResult result = + getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); + assert(result.succeeded() && "expected a backward slice"); slice.insert_range(backwardSlice); // Compute and insert the forwardSlice starting from currentOp. diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index d33a17af63459..2c98bd3ba93af 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -124,10 +124,13 @@ static void computeBackwardSlice(tensor::PadOp padOp, getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(), valuesDefinedAbove); for (Value v : valuesDefinedAbove) { - getBackwardSlice(v, &backwardSlice, sliceOptions); + LogicalResult result = getBackwardSlice(v, &backwardSlice, sliceOptions); + assert(result.succeeded() && "expected a backward slice"); } // Then, add the backward slice from padOp itself. - getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions); + LogicalResult result = + getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions); + assert(result.succeeded() && "expected a backward slice"); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 75dbe0becf80d..1046f5798ecd4 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -290,8 +290,10 @@ static void getPipelineStages( }); options.inclusive = true; for (Operation &op : forOp.getBody()->getOperations()) { - if (stage0Ops.contains(&op)) - getBackwardSlice(&op, &dependencies, options); + if (stage0Ops.contains(&op)) { + LogicalResult result = getBackwardSlice(&op, &dependencies, options); + assert(result.succeeded() && "expected a backward slice"); + } } for (Operation &op : forOp.getBody()->getOperations()) { diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 719e2c6fa459e..9e3d3f8b10a13 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1772,7 +1772,8 @@ checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, }; llvm::SetVector slice; for (auto operand : consumerOp->getOperands()) { - getBackwardSlice(operand, &slice, options); + LogicalResult result = getBackwardSlice(operand, &slice, options); + assert(result.succeeded() && "expected a backward slice"); } if (!slice.empty()) { diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 4985d718c1780..c136ff92255cd 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -1094,7 +1094,8 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); }; llvm::SetVector slice; - getBackwardSlice(op, &slice, options); + LogicalResult result = getBackwardSlice(op, &slice, options); + assert(result.succeeded() && "expected a backward slice"); // If the slice contains `insertionPoint` cannot move the dependencies. if (slice.contains(insertionPoint)) { @@ -1159,7 +1160,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, }; llvm::SetVector slice; for (auto value : prunedValues) { - getBackwardSlice(value, &slice, options); + LogicalResult result = getBackwardSlice(value, &slice, options); + assert(result.succeeded() && "expected a backward slice"); } // If the slice contains `insertionPoint` cannot move the dependencies. diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp index f26058f30ad7b..145acd99e6616 100644 --- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp @@ -154,7 +154,9 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) { patternTestSlicingOps().match(f, &matches); for (auto m : matches) { SetVector backwardSlice; - getBackwardSlice(m.getMatchedOperation(), &backwardSlice); + LogicalResult result = + getBackwardSlice(m.getMatchedOperation(), &backwardSlice); + assert(result.succeeded() && "expected a backward slice"); outs << "\nmatched: " << *m.getMatchedOperation() << " backward static slice: "; for (auto *op : backwardSlice) diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp index e99d5976d6d9d..ad99be2b9d0c9 100644 --- a/mlir/test/lib/IR/TestSlicing.cpp +++ b/mlir/test/lib/IR/TestSlicing.cpp @@ -41,7 +41,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op, options.omitBlockArguments = omitBlockArguments; // TODO: Make this default. options.omitUsesFromAbove = false; - getBackwardSlice(op, &slice, options); + LogicalResult result = getBackwardSlice(op, &slice, options); + assert(result.succeeded() && "expected a backward slice"); for (Operation *slicedOp : slice) builder.clone(*slicedOp, mapper); builder.create(loc);