Skip to content

Commit 8a38077

Browse files
committed
fixup! fixup! Enable LICM for ops with read side effects in scf.for wrapped by a guard
1 parent c532810 commit 8a38077

File tree

8 files changed

+154
-201
lines changed

8 files changed

+154
-201
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def SCF_Dialect : Dialect {
4040
and then lowered to some final target like LLVM or SPIR-V.
4141
}];
4242

43-
let dependentDialects = ["arith::ArithDialect"];
43+
let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
4444
}
4545

4646
// Base class for SCF dialect ops.
@@ -138,6 +138,7 @@ def ForOp : SCF_Op<"for",
138138
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
139139
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
140140
"getLoopUpperBounds", "getYieldedValuesMutable",
141+
"moveOutOfLoopWithGuard",
141142
"promoteIfSingleIteration", "replaceWithAdditionalYields",
142143
"wrapInTripCountCheck", "unwrapTripCountCheck",
143144
"yieldTiledValuesAndReplace"]>,

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,23 +80,15 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
8080
/*defaultImplementation=*/"op->moveBefore($_op);"
8181
>,
8282
InterfaceMethod<[{
83-
Wraps the loop into a trip-count check.
83+
Moves the given loop-invariant operation out of the loop with a
84+
trip-count guard.
8485
}],
85-
/*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
86-
/*methodName=*/"wrapInTripCountCheck",
87-
/*args=*/(ins),
88-
/*methodBody=*/"",
89-
/*defaultImplementation=*/"return ::mlir::failure();"
90-
>,
91-
InterfaceMethod<[{
92-
Unwraps the trip-count check.
93-
}],
94-
/*retTy=*/"::llvm::LogicalResult",
95-
/*methodName=*/"unwrapTripCountCheck",
96-
/*args=*/(ins),
86+
/*retTy=*/"void",
87+
/*methodName=*/"moveOutOfLoopWithGuard",
88+
/*args=*/(ins "::mlir::Operation *":$op),
9789
/*methodBody=*/"",
9890
/*defaultImplementation=*/[{
99-
return ::mlir::failure();
91+
return;
10092
}]
10193
>,
10294
InterfaceMethod<[{

mlir/include/mlir/Interfaces/SideEffectInterfaces.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,9 @@ bool wouldOpBeTriviallyDead(Operation *op);
433433
/// conditions are satisfied.
434434
bool isMemoryEffectFree(Operation *op);
435435

436-
/// Returns true if the given operation implements `MemoryEffectOpInterface` and
437-
/// has only read effects.
438-
bool hasOnlyReadEffect(Operation *op);
436+
/// Returns true if the given operation is free of memory effects or has only
437+
/// read effects.
438+
bool isMemoryEffectFreeOrOnlyRead(Operation *op);
439439

440440
/// Returns the side effects of an operation. If the operation has
441441
/// RecursiveMemoryEffects, include all side effects of child operations.

mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,19 @@ class Value;
4747
/// }
4848
/// }
4949
/// ```
50-
///
51-
/// Users must supply five callbacks.
50+
/// Users must supply four callbacks.
5251
///
5352
/// - `isDefinedOutsideRegion` returns true if the given value is invariant with
5453
/// respect to the given region. A common implementation might be:
5554
/// `value.getParentRegion()->isProperAncestor(region)`.
5655
/// - `shouldMoveOutOfRegion` returns true if the provided operation can be
5756
/// moved out of the given region, e.g. if it is side-effect free or has only read
5857
/// side effects.
59-
/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
60-
/// - `moveOutOfRegion` moves the operation out of the given region. A common
61-
/// implementation might be: `op->moveBefore(region->getParentOp())`.
62-
/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
63-
/// this check.
58+
/// - `moveOutOfRegionWithoutGuard` moves the operation out of the given region
59+
/// without a guard. A common implementation might be:
60+
/// `op->moveBefore(region->getParentOp())`.
61+
/// - `moveOutOfRegionWithGuard` moves the operation out of the given region
62+
/// with a guard.
6463
///
6564
/// An operation is moved if all of its operands satisfy
6665
/// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -70,9 +69,8 @@ size_t moveLoopInvariantCode(
7069
ArrayRef<Region *> regions,
7170
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
7271
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
73-
function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
74-
function_ref<void(Operation *, Region *)> moveOutOfRegion,
75-
function_ref<LogicalResult()> unwrapGuard);
72+
function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
73+
function_ref<void(Operation *)> moveOutOfRegionWithGuard);
7674

7775
/// Move side-effect free loop invariant code out of a loop-like op using
7876
/// methods provided by the interface.

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1616
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
1717
#include "mlir/Dialect/Tensor/IR/Tensor.h"
18+
#include "mlir/Dialect/UB/IR/UBOps.h"
1819
#include "mlir/IR/BuiltinAttributes.h"
1920
#include "mlir/IR/IRMapping.h"
2021
#include "mlir/IR/Matchers.h"
@@ -395,58 +396,38 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
395396

396397
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397398

398-
FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
399-
399+
/// Moves the op out of the loop with a guard that checks if the loop has at
400+
/// least one iteration.
401+
void ForOp::moveOutOfLoopWithGuard(Operation *op) {
400402
IRRewriter rewriter(this->getContext());
401403
OpBuilder::InsertionGuard insertGuard(rewriter);
402-
rewriter.setInsertionPointAfter(this->getOperation());
403-
404-
auto loc = this->getLoc();
405-
auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
406-
this->getUpperBound(),
407-
this->getLowerBound());
408-
scf::YieldOp yieldInThen;
404+
rewriter.setInsertionPoint(this->getOperation());
405+
Location loc = this->getLoc();
406+
arith::CmpIOp cmpIOp = rewriter.create<arith::CmpIOp>(
407+
loc, arith::CmpIPredicate::ult, this->getLowerBound(),
408+
this->getUpperBound());
409409
// Create the trip-count check.
410-
auto ifOp = rewriter.create<scf::IfOp>(
410+
scf::YieldOp thenYield;
411+
scf::IfOp ifOp = rewriter.create<scf::IfOp>(
411412
loc, cmpIOp,
412413
[&](OpBuilder &builder, Location loc) {
413-
yieldInThen = builder.create<scf::YieldOp>(loc, this->getResults());
414+
thenYield = builder.create<scf::YieldOp>(loc, op->getResults());
414415
},
415416
[&](OpBuilder &builder, Location loc) {
416-
builder.create<scf::YieldOp>(loc, this->getInitArgs());
417+
SmallVector<Value> poisonResults;
418+
poisonResults.reserve(op->getResults().size());
419+
for (Type type : op->getResults().getTypes()) {
420+
ub::PoisonOp poisonOp =
421+
rewriter.create<ub::PoisonOp>(loc, type, nullptr);
422+
poisonResults.push_back(poisonOp);
423+
}
424+
builder.create<scf::YieldOp>(loc, poisonResults);
417425
});
418-
419-
for (auto [forOpResult, ifOpResult] :
420-
llvm::zip(this->getResults(), ifOp.getResults()))
421-
rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
422-
// Move the scf.for into the then block.
423-
rewriter.moveOpBefore(this->getOperation(), yieldInThen);
424-
return std::make_pair(ifOp.getOperation(), &this->getRegion());
425-
}
426-
427-
LogicalResult ForOp::unwrapTripCountCheck() {
428-
auto ifOp = (*this)->getParentRegion()->getParentOp();
429-
if (!isa<scf::IfOp>(ifOp))
430-
return failure();
431-
432-
IRRewriter rewriter(ifOp->getContext());
433-
OpBuilder::InsertionGuard insertGuard(rewriter);
434-
rewriter.setInsertionPoint(ifOp);
435-
436-
auto cmpOp = ifOp->getOperand(0).getDefiningOp();
437-
if (!isa<arith::CmpIOp>(cmpOp))
438-
return failure();
439-
440-
auto wrappedForOp = this->getOperation();
441-
rewriter.moveOpBefore(wrappedForOp, ifOp);
442-
443-
for (auto [forOpResult, ifOpResult] :
444-
llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
445-
rewriter.replaceAllUsesWith(ifOpResult, forOpResult);
446-
447-
rewriter.eraseOp(ifOp);
448-
rewriter.eraseOp(cmpOp);
449-
return success();
426+
for (auto [opResult, ifOpResult] :
427+
llvm::zip(op->getResults(), ifOp->getResults()))
428+
rewriter.replaceAllUsesExcept(opResult, ifOpResult, thenYield);
429+
// Move the op into the then block.
430+
rewriter.moveOpBefore(op, thenYield);
450431
}
451432

452433
/// Promotes the loop body of a forOp to its containing block if the forOp

mlir/lib/Interfaces/SideEffectInterfaces.cpp

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -306,26 +306,6 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
306306
return wouldOpBeTriviallyDeadImpl(op);
307307
}
308308

309-
bool mlir::hasOnlyReadEffect(Operation *op) {
310-
if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
311-
if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
312-
return memEffects.onlyHasEffect<MemoryEffects::Read>();
313-
} else if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
314-
// Otherwise, if the op does not implement the memory effect interface and
315-
// it does not have recursive side effects, then it cannot be known that the
316-
// op is moveable.
317-
return false;
318-
}
319-
320-
// Recurse into the regions and ensure that all nested ops are memory effect
321-
// free.
322-
for (Region &region : op->getRegions())
323-
for (Operation &op : region.getOps())
324-
if (!hasOnlyReadEffect(&op))
325-
return false;
326-
return true;
327-
}
328-
329309
bool mlir::isMemoryEffectFree(Operation *op) {
330310
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
331311
if (!memInterface.hasNoEffect())
@@ -383,6 +363,16 @@ mlir::getEffectsRecursively(Operation *rootOp) {
383363
return effects;
384364
}
385365

366+
bool mlir::isMemoryEffectFreeOrOnlyRead(Operation *op) {
367+
std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
368+
getEffectsRecursively(op);
369+
if (!effects)
370+
return false;
371+
return std::all_of(effects->begin(), effects->end(), [](auto &effect) {
372+
return isa<MemoryEffects::Read>(effect.getEffect());
373+
});
374+
}
375+
386376
bool mlir::isSpeculatable(Operation *op) {
387377
auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
388378
if (!conditionallySpeculatable)

mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp

Lines changed: 24 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -56,117 +56,67 @@ static bool canBeHoisted(Operation *op,
5656
op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
5757
}
5858

59-
static bool dependsOnGuarded(Operation *op,
60-
function_ref<bool(OpOperand &)> condition) {
61-
auto walkFn = [&](Operation *child) {
62-
for (OpOperand &operand : child->getOpOperands()) {
63-
if (!condition(operand))
64-
return WalkResult::interrupt();
65-
}
66-
return WalkResult::advance();
67-
};
68-
return op->walk(walkFn).wasInterrupted();
69-
}
70-
71-
static bool dependsOnGuarded(Operation *op,
72-
function_ref<bool(Value)> definedOutsideGuard) {
73-
return dependsOnGuarded(op, [&](OpOperand &operand) {
74-
return definedOutsideGuard(operand.get());
75-
});
76-
}
77-
78-
static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
79-
for (Region &region : loop->getRegions()) {
80-
for (Block &block : region.getBlocks()) {
81-
for (Operation &op : block.getOperations()) {
82-
if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
83-
return false;
84-
}
85-
}
86-
}
87-
return true;
88-
}
89-
9059
size_t mlir::moveLoopInvariantCode(
9160
ArrayRef<Region *> regions,
9261
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
9362
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
94-
function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
95-
function_ref<void(Operation *, Region *)> moveOutOfRegion,
96-
function_ref<LogicalResult()> unwrapGuard) {
63+
function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
64+
function_ref<void(Operation *)> moveOutOfRegionWithGuard) {
9765
size_t numMoved = 0;
9866

9967
for (Region *region : regions) {
10068
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
10169
<< *region->getParentOp() << "\n");
10270

103-
auto loopSideEffectFreeOrHasOnlyReadSideEffect =
104-
loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
105-
106-
size_t numMovedWithoutGuard = 0;
107-
108-
FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
109-
Region *loopRegion = region;
110-
auto isLoopWrapped = false;
111-
if (succeeded(ifOpAndRegion)) {
112-
loopRegion = ifOpAndRegion->second;
113-
isLoopWrapped = true;
114-
}
71+
bool anyOpHoistedWithGuard = false;
72+
bool loopSideEffectFreeOrHasOnlyReadSideEffect =
73+
isMemoryEffectFreeOrOnlyRead(region->getParentOp());
11574

11675
std::queue<Operation *> worklist;
11776
// Add top-level operations in the loop body to the worklist.
118-
for (Operation &op : loopRegion->getOps())
77+
for (Operation &op : region->getOps())
11978
worklist.push(&op);
12079

12180
auto definedOutside = [&](Value value) {
122-
return isDefinedOutsideRegion(value, loopRegion);
123-
};
124-
125-
auto definedOutsideGuard = [&](Value value) {
126-
return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
81+
return isDefinedOutsideRegion(value, region);
12782
};
12883

12984
while (!worklist.empty()) {
13085
Operation *op = worklist.front();
13186
worklist.pop();
13287
// Skip ops that have already been moved. Check if the op can be hoisted.
133-
if (op->getParentRegion() != loopRegion)
88+
if (op->getParentRegion() != region)
13489
continue;
13590

13691
LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
13792

138-
if (!shouldMoveOutOfRegion(op, loopRegion) ||
93+
if (!shouldMoveOutOfRegion(op, region) ||
13994
!canBeHoisted(op, definedOutside))
14095
continue;
14196
// Can only hoist pure ops (side-effect free) when there is an op with
14297
// write and/or unknown side effects in the loop.
14398
if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
14499
continue;
145100

146-
LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
147-
148-
auto moveWithoutGuard = isMemoryEffectFree(op) &&
149-
!dependsOnGuarded(op, definedOutsideGuard) &&
150-
isLoopWrapped;
151-
numMovedWithoutGuard += moveWithoutGuard;
152-
153-
moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
154-
: loopRegion);
101+
bool moveWithoutGuard = !anyOpHoistedWithGuard && isMemoryEffectFree(op);
102+
if (moveWithoutGuard) {
103+
LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op
104+
<< " without guard\n");
105+
moveOutOfRegionWithoutGuard(op);
106+
} else {
107+
LLVM_DEBUG(llvm::dbgs()
108+
<< "Moving loop-invariant op: " << *op << " with guard\n");
109+
moveOutOfRegionWithGuard(op);
110+
anyOpHoistedWithGuard = true;
111+
}
155112
++numMoved;
156113

157114
// Since the op has been moved, we need to check its users within the
158115
// top-level of the loop body.
159116
for (Operation *user : op->getUsers())
160-
if (user->getParentRegion() == loopRegion)
117+
if (user->getParentRegion() == region)
161118
worklist.push(user);
162119
}
163-
164-
// Unwrap the loop if it was wrapped but no ops were moved in the guard.
165-
if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
166-
auto tripCountCheckUnwrapped = unwrapGuard();
167-
if (failed(tripCountCheckUnwrapped))
168-
llvm_unreachable("Should not fail unwrapping trip-count check");
169-
}
170120
}
171121

172122
return numMoved;
@@ -179,14 +129,10 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
179129
return !region->isAncestor(value.getParentRegion());
180130
},
181131
[&](Operation *op, Region *) {
182-
return isSpeculatable(op) &&
183-
(isMemoryEffectFree(op) || hasOnlyReadEffect(op));
184-
},
185-
[&]() { return loopLike.wrapInTripCountCheck(); },
186-
[&](Operation *op, Region *region) {
187-
op->moveBefore(region->getParentOp());
132+
return isSpeculatable(op) && isMemoryEffectFreeOrOnlyRead(op);
188133
},
189-
[&]() { return loopLike.unwrapTripCountCheck(); });
134+
[&](Operation *op) { loopLike.moveOutOfLoop(op); },
135+
[&](Operation *op) { loopLike.moveOutOfLoopWithGuard(op); });
190136
}
191137

192138
namespace {

0 commit comments

Comments
 (0)