Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def ForOp : SCF_Op<"for",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
"getLoopUpperBounds", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"wrapInTripCountCheck", "unwrapTripCountCheck",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
Expand Down Expand Up @@ -302,7 +303,7 @@ def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
Expand Down
20 changes: 20 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/"op->moveBefore($_op);"
>,
InterfaceMethod<[{
Wraps the loop into a trip-count check.
}],
/*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
/*methodName=*/"wrapInTripCountCheck",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return ::mlir::failure();"
>,
InterfaceMethod<[{
Unwraps the trip-count check.
}],
/*retTy=*/"::llvm::LogicalResult",
/*methodName=*/"unwrapTripCountCheck",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::mlir::failure();
}]
>,
InterfaceMethod<[{
Promotes the loop body to its containing block if the loop is known to
have a single iteration. Returns "success" if the promotion was
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Interfaces/SideEffectInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
/// conditions are satisfied.
bool isMemoryEffectFree(Operation *op);

/// Returns true if the given operation implements `MemoryEffectOpInterface` and
/// has only read effects.
bool hasOnlyReadEffect(Operation *op);

/// Returns the side effects of an operation. If the operation has
/// RecursiveMemoryEffects, include all side effects of child operations.
///
Expand Down
12 changes: 9 additions & 3 deletions mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,19 @@ class Value;
/// }
/// ```
///
/// Users must supply three callbacks.
/// Users must supply five callbacks.
///
/// - `isDefinedOutsideRegion` returns true if the given value is invariant with
/// respect to the given region. A common implementation might be:
/// `value.getParentRegion()->isProperAncestor(region)`.
/// - `shouldMoveOutOfRegion` returns true if the provided operation can be
/// moved of the given region, e.g. if it is side-effect free.
/// moved of the given region, e.g. if it is side-effect free or has only read
/// side effects.
/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
/// - `moveOutOfRegion` moves the operation out of the given region. A common
/// implementation might be: `op->moveBefore(region->getParentOp())`.
/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
/// this check.
///
/// An operation is moved if all of its operands satisfy
/// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
Expand All @@ -66,7 +70,9 @@ size_t moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
function_ref<void(Operation *, Region *)> moveOutOfRegion);
function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
function_ref<void(Operation *, Region *)> moveOutOfRegion,
function_ref<LogicalResult()> unwrapGuard);

/// Move side-effect free loop invariant code out of a loop-like op using
/// methods provided by the interface.
Expand Down
59 changes: 56 additions & 3 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,60 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {

std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }

FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strip blank lines

IRRewriter rewriter(this->getContext());
OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointAfter(this->getOperation());

auto loc = this->getLoc();
auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
this->getUpperBound(),
this->getLowerBound());
scf::YieldOp yieldInThen;
// Create the trip-count check.
auto ifOp = rewriter.create<scf::IfOp>(
loc, cmpIOp,
[&](OpBuilder &builder, Location loc) {
yieldInThen = builder.create<scf::YieldOp>(loc, this->getResults());
},
[&](OpBuilder &builder, Location loc) {
builder.create<scf::YieldOp>(loc, this->getInitArgs());
});

for (auto [forOpResult, ifOpResult] :
llvm::zip(this->getResults(), ifOp.getResults()))
rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
// Move the scf.for into the then block.
rewriter.moveOpBefore(this->getOperation(), yieldInThen);
return std::make_pair(ifOp.getOperation(), &this->getRegion());
}

LogicalResult ForOp::unwrapTripCountCheck() {
auto ifOp = (*this)->getParentRegion()->getParentOp();
if (!isa<scf::IfOp>(ifOp))
return failure();

IRRewriter rewriter(ifOp->getContext());
OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPoint(ifOp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This requires that this function be called immediately after wrapInTripCountCheck. How can this be guaranteed?

Copy link
Contributor Author

@ardaunal ardaunal Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the check here.

This checks if this region is wrapped and there is no op that needs to be guarded when being hoisted, then it unwraps. This function is called only then.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reply, I'll look into details later.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Functions implemented by interfaces should be callable from anywhere. If there is a requirement for functions of multiple interfaces to be called in a specific order, it is recommended not to use interfaces.


auto cmpOp = ifOp->getOperand(0).getDefiningOp();
if (!isa<arith::CmpIOp>(cmpOp))
return failure();

auto wrappedForOp = this->getOperation();
rewriter.moveOpBefore(wrappedForOp, ifOp);

for (auto [forOpResult, ifOpResult] :
llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
rewriter.replaceAllUsesWith(ifOpResult, forOpResult);

rewriter.eraseOp(ifOp);
rewriter.eraseOp(cmpOp);
return success();
}

/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
Expand Down Expand Up @@ -3397,9 +3451,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {

if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
<< "expected as many input types as operands "
<< "(expected " << operands.size() << " got "
<< functionType.getNumInputs() << ")";
<< "expected as many input types as operands " << "(expected "
<< operands.size() << " got " << functionType.getNumInputs() << ")";
}

// Resolve input operands.
Expand Down
20 changes: 20 additions & 0 deletions mlir/lib/Interfaces/SideEffectInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,26 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
return wouldOpBeTriviallyDeadImpl(op);
}

bool mlir::hasOnlyReadEffect(Operation *op) {
if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
return memEffects.onlyHasEffect<MemoryEffects::Read>();
} else if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
// Otherwise, if the op does not implement the memory effect interface and
// it does not have recursive side effects, then it cannot be known that the
// op is moveable.
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation is not robust and can hardly handle any operations with regions, because their terminators are inherently memory effect free.

}

// Recurse into the regions and ensure that all nested ops are memory effect
// free.
for (Region &region : op->getRegions())
for (Operation &op : region.getOps())
if (!hasOnlyReadEffect(&op))
return false;
return true;
}

bool mlir::isMemoryEffectFree(Operation *op) {
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
if (!memInterface.hasNoEffect())
Expand Down
96 changes: 85 additions & 11 deletions mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
}

static bool dependsOnGuarded(Operation *op,
function_ref<bool(OpOperand &)> condition) {
auto walkFn = [&](Operation *child) {
for (OpOperand &operand : child->getOpOperands()) {
if (!condition(operand))
return WalkResult::interrupt();
}
return WalkResult::advance();
};
return op->walk(walkFn).wasInterrupted();
}

static bool dependsOnGuarded(Operation *op,
function_ref<bool(Value)> definedOutsideGuard) {
return dependsOnGuarded(op, [&](OpOperand &operand) {
return definedOutsideGuard(operand.get());
});
}

static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
for (Region &region : loop->getRegions()) {
for (Block &block : region.getBlocks()) {
for (Operation &op : block.getOperations()) {
if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
return false;
}
}
}
return true;
}

size_t mlir::moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
function_ref<void(Operation *, Region *)> moveOutOfRegion) {
function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
function_ref<void(Operation *, Region *)> moveOutOfRegion,
function_ref<LogicalResult()> unwrapGuard) {
size_t numMoved = 0;

for (Region *region : regions) {
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
<< *region->getParentOp() << "\n");

auto loopSideEffectFreeOrHasOnlyReadSideEffect =
loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());

size_t numMovedWithoutGuard = 0;

FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
Region *loopRegion = region;
auto isLoopWrapped = false;
if (succeeded(ifOpAndRegion)) {
loopRegion = ifOpAndRegion->second;
isLoopWrapped = true;
}

std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
for (Operation &op : region->getOps())
for (Operation &op : loopRegion->getOps())
worklist.push(&op);

auto definedOutside = [&](Value value) {
return isDefinedOutsideRegion(value, region);
return isDefinedOutsideRegion(value, loopRegion);
};

auto definedOutsideGuard = [&](Value value) {
return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
};

while (!worklist.empty()) {
Operation *op = worklist.front();
worklist.pop();
// Skip ops that have already been moved. Check if the op can be hoisted.
if (op->getParentRegion() != region)
if (op->getParentRegion() != loopRegion)
continue;

LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
if (!shouldMoveOutOfRegion(op, region) ||

if (!shouldMoveOutOfRegion(op, loopRegion) ||
!canBeHoisted(op, definedOutside))
continue;
// Can only hoist pure ops (side-effect free) when there is an op with
// write and/or unknown side effects in the loop.
if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
continue;

LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
moveOutOfRegion(op, region);

auto moveWithoutGuard = isMemoryEffectFree(op) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why auto

!dependsOnGuarded(op, definedOutsideGuard) &&
isLoopWrapped;
numMovedWithoutGuard += moveWithoutGuard;

moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
: loopRegion);
++numMoved;

// Since the op has been moved, we need to check its users within the
// top-level of the loop body.
for (Operation *user : op->getUsers())
if (user->getParentRegion() == region)
if (user->getParentRegion() == loopRegion)
worklist.push(user);
}

// Unwrap the loop if it was wrapped but no ops were moved in the guard.
if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
auto tripCountCheckUnwrapped = unwrapGuard();
if (failed(tripCountCheckUnwrapped))
llvm_unreachable("Should not fail unwrapping trip-count check");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is unwarp part necessary?Can we use canonicalize to achieve that?

}

return numMoved;
Expand All @@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
return moveLoopInvariantCode(
loopLike.getLoopRegions(),
[&](Value value, Region *) {
return loopLike.isDefinedOutsideOfLoop(value);
[&](Value value, Region *region) {
return !region->isAncestor(value.getParentRegion());
},
[&](Operation *op, Region *) {
return isMemoryEffectFree(op) && isSpeculatable(op);
return isSpeculatable(op) &&
(isMemoryEffectFree(op) || hasOnlyReadEffect(op));
},
[&]() { return loopLike.wrapInTripCountCheck(); },
[&](Operation *op, Region *region) {
op->moveBefore(region->getParentOp());
},
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
[&]() { return loopLike.unwrapTripCountCheck(); });
}

namespace {
Expand Down
Loading
Loading