Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mlir/include/mlir/Analysis/SliceAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,15 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
/// Assuming all local orders match the numbering order:
/// {1, 2, 5, 3, 4, 6}
///
void getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options = {});
LogicalResult getBackwardSlice(Operation *op,
SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options = {});

/// Value-rooted version of `getBackwardSlice`. Return the union of all backward
/// slices for the op defining or owning the value `root`.
void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options = {});
LogicalResult getBackwardSlice(Value root,
SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options = {});

/// Iteratively computes backward slices and forward slices until
/// a fixed point is reached. Returns an `SetVector<Operation *>` which
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Query/Matcher/SliceMatchers.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ bool BackwardSliceMatcher<Matcher>::matches(
}
return true;
};
getBackwardSlice(rootOp, &backwardSlice, options);
auto result = getBackwardSlice(rootOp, &backwardSlice, options);
assert(result.succeeded());
return options.inclusive ? backwardSlice.size() > 1
: backwardSlice.size() >= 1;
}
Expand Down
52 changes: 32 additions & 20 deletions mlir/lib/Analysis/SliceAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,25 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
forwardSlice->insert(v.rbegin(), v.rend());
}

static void getBackwardSliceImpl(Operation *op,
SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options) {
static LogicalResult getBackwardSliceImpl(Operation *op,
SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options) {
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
return;
return success();

// Evaluate whether we should keep this def.
// This is useful in particular to implement scoping; i.e. return the
// transitive backwardSlice in the current scope.
if (options.filter && !options.filter(op))
return;
return success();

bool succeeded = true;

auto processValue = [&](Value value) {
if (auto *definingOp = value.getDefiningOp()) {
if (backwardSlice->count(definingOp) == 0)
getBackwardSliceImpl(definingOp, backwardSlice, options);
succeeded &= getBackwardSliceImpl(definingOp, backwardSlice, options)
.succeeded();
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (options.omitBlockArguments)
return;
Expand All @@ -106,9 +109,13 @@ static void getBackwardSliceImpl(Operation *op,
// blocks of parentOp, which are not technically backward unless they flow
// into us. For now, just bail.
if (parentOp && backwardSlice->count(parentOp) == 0) {
assert(parentOp->getNumRegions() == 1 &&
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
getBackwardSliceImpl(parentOp, backwardSlice, options);
if (parentOp->getNumRegions() == 1 &&
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) {
succeeded &= getBackwardSliceImpl(parentOp, backwardSlice, options)
.succeeded();
} else {
succeeded = false;
}
}
} else {
llvm_unreachable("No definingOp and not a block argument.");
Expand All @@ -133,28 +140,30 @@ static void getBackwardSliceImpl(Operation *op,
llvm::for_each(op->getOperands(), processValue);

backwardSlice->insert(op);
return success(succeeded);
}

void mlir::getBackwardSlice(Operation *op,
SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options) {
getBackwardSliceImpl(op, backwardSlice, options);
LogicalResult
mlir::getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options) {
LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options);

if (!options.inclusive) {
// Don't insert the top level operation, we just queried on it and don't
// want it in the results.
backwardSlice->remove(op);
}
return result;
}

void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options) {
LogicalResult mlir::getBackwardSlice(Value root,
SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options) {
if (Operation *definingOp = root.getDefiningOp()) {
getBackwardSlice(definingOp, backwardSlice, options);
return;
return getBackwardSlice(definingOp, backwardSlice, options);
}
Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
getBackwardSlice(bbAargOwner, backwardSlice, options);
return getBackwardSlice(bbAargOwner, backwardSlice, options);
}

SetVector<Operation *>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to return FailureOr<SetVector<Operation *>> here for consistency with the rest of the API.

Expand All @@ -170,7 +179,9 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
auto result =
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
assert(result.succeeded());
slice.insert_range(backwardSlice);

// Compute and insert the forwardSlice starting from currentOp.
Expand All @@ -193,7 +204,8 @@ static bool dependsOnCarriedVals(Value value,
sliceOptions.filter = [&](Operation *op) {
return !ancestorOp->isAncestor(op);
};
getBackwardSlice(value, &slice, sliceOptions);
auto result = getBackwardSlice(value, &slice, sliceOptions);
assert(result.succeeded());

// Check that none of the operands of the operations in the backward slice are
// loop iteration arguments, and neither is the value itself.
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,9 @@ getSliceContract(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
auto result =
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
assert(result.succeeded());
slice.insert_range(backwardSlice);

// Compute and insert the forwardSlice starting from currentOp.
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,13 @@ static void computeBackwardSlice(tensor::PadOp padOp,
getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
valuesDefinedAbove);
for (Value v : valuesDefinedAbove) {
getBackwardSlice(v, &backwardSlice, sliceOptions);
auto result = getBackwardSlice(v, &backwardSlice, sliceOptions);
assert(result.succeeded());
}
// Then, add the backward slice from padOp itself.
getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
auto result =
getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
assert(result.succeeded());
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,10 @@ static void getPipelineStages(
});
options.inclusive = true;
for (Operation &op : forOp.getBody()->getOperations()) {
if (stage0Ops.contains(&op))
getBackwardSlice(&op, &dependencies, options);
if (stage0Ops.contains(&op)) {
auto result = getBackwardSlice(&op, &dependencies, options);
assert(result.succeeded());
}
}

for (Operation &op : forOp.getBody()->getOperations()) {
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1772,7 +1772,8 @@ checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
};
llvm::SetVector<Operation *> slice;
for (auto operand : consumerOp->getOperands()) {
getBackwardSlice(operand, &slice, options);
auto result = getBackwardSlice(operand, &slice, options);
assert(result.succeeded());
}

if (!slice.empty()) {
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,8 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
};
llvm::SetVector<Operation *> slice;
getBackwardSlice(op, &slice, options);
auto result = getBackwardSlice(op, &slice, options);
assert(result.succeeded());

// If the slice contains `insertionPoint` cannot move the dependencies.
if (slice.contains(insertionPoint)) {
Expand Down Expand Up @@ -1159,7 +1160,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
};
llvm::SetVector<Operation *> slice;
for (auto value : prunedValues) {
getBackwardSlice(value, &slice, options);
auto result = getBackwardSlice(value, &slice, options);
assert(result.succeeded());
}

// If the slice contains `insertionPoint` cannot move the dependencies.
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
SetVector<Operation *> backwardSlice;
getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
auto result = getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
assert(result.succeeded());
outs << "\nmatched: " << *m.getMatchedOperation()
<< " backward static slice: ";
for (auto *op : backwardSlice)
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/lib/IR/TestSlicing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
options.omitBlockArguments = omitBlockArguments;
// TODO: Make this default.
options.omitUsesFromAbove = false;
getBackwardSlice(op, &slice, options);
auto result = getBackwardSlice(op, &slice, options);
assert(result.succeeded());
for (Operation *slicedOp : slice)
builder.clone(*slicedOp, mapper);
builder.create<func::ReturnOp>(loc);
Expand Down
Loading