@@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
5656 op, [&](OpOperand &operand) { return definedOutside (operand.get ()); });
5757}
5858
59+ static bool dependsOnGuarded (Operation *op,
60+ function_ref<bool (OpOperand &)> condition) {
61+ auto walkFn = [&](Operation *child) {
62+ for (OpOperand &operand : child->getOpOperands ()) {
63+ if (!condition (operand))
64+ return WalkResult::interrupt ();
65+ }
66+ return WalkResult::advance ();
67+ };
68+ return op->walk (walkFn).wasInterrupted ();
69+ }
70+
71+ static bool dependsOnGuarded (Operation *op,
72+ function_ref<bool (Value)> definedOutsideGuard) {
73+ return dependsOnGuarded (op, [&](OpOperand &operand) {
74+ return definedOutsideGuard (operand.get ());
75+ });
76+ }
77+
78+ static bool loopSideEffectFreeOrHasOnlyReadEffect (Operation *loop) {
79+ for (Region ®ion : loop->getRegions ()) {
80+ for (Block &block : region.getBlocks ()) {
81+ for (Operation &op : block.getOperations ()) {
82+ if (!isMemoryEffectFree (&op) && !hasOnlyReadEffect (&op))
83+ return false ;
84+ }
85+ }
86+ }
87+ return true ;
88+ }
89+
5990size_t mlir::moveLoopInvariantCode (
6091 ArrayRef<Region *> regions,
6192 function_ref<bool (Value, Region *)> isDefinedOutsideRegion,
6293 function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
63- function_ref<void(Operation *, Region *)> moveOutOfRegion) {
94+ function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
95+ function_ref<void(Operation *, Region *)> moveOutOfRegion,
96+ function_ref<LogicalResult()> unwrapGuard) {
6497 size_t numMoved = 0 ;
6598
6699 for (Region *region : regions) {
67100 LLVM_DEBUG (llvm::dbgs () << " Original loop:\n "
68101 << *region->getParentOp () << " \n " );
69102
103+ auto loopSideEffectFreeOrHasOnlyReadSideEffect =
104+ loopSideEffectFreeOrHasOnlyReadEffect (region->getParentOp ());
105+
106+ size_t numMovedWithoutGuard = 0 ;
107+
108+ FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard ();
109+ Region *loopRegion = region;
110+ auto isLoopWrapped = false ;
111+ if (succeeded (ifOpAndRegion)) {
112+ loopRegion = ifOpAndRegion->second ;
113+ isLoopWrapped = true ;
114+ }
115+
70116 std::queue<Operation *> worklist;
71117 // Add top-level operations in the loop body to the worklist.
72- for (Operation &op : region ->getOps ())
118+ for (Operation &op : loopRegion ->getOps ())
73119 worklist.push (&op);
74120
75121 auto definedOutside = [&](Value value) {
76- return isDefinedOutsideRegion (value, region);
122+ return isDefinedOutsideRegion (value, loopRegion);
123+ };
124+
125+ auto definedOutsideGuard = [&](Value value) {
126+ return isDefinedOutsideRegion (value, loopRegion->getParentRegion ());
77127 };
78128
79129 while (!worklist.empty ()) {
80130 Operation *op = worklist.front ();
81131 worklist.pop ();
82132 // Skip ops that have already been moved. Check if the op can be hoisted.
83- if (op->getParentRegion () != region )
133+ if (op->getParentRegion () != loopRegion )
84134 continue ;
85135
86136 LLVM_DEBUG (llvm::dbgs () << " Checking op: " << *op << " \n " );
87- if (!shouldMoveOutOfRegion (op, region) ||
137+
138+ if (!shouldMoveOutOfRegion (op, loopRegion) ||
88139 !canBeHoisted (op, definedOutside))
89140 continue ;
141+ // Can only hoist pure ops (side-effect free) when there is an op with
142+ // write side effects in the loop.
143+ if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree (op))
144+ continue ;
90145
91146 LLVM_DEBUG (llvm::dbgs () << " Moving loop-invariant op: " << *op << " \n " );
92- moveOutOfRegion (op, region);
147+
148+ auto moveWithoutGuard = isMemoryEffectFree (op) &&
149+ !dependsOnGuarded (op, definedOutsideGuard) &&
150+ isLoopWrapped;
151+ numMovedWithoutGuard += moveWithoutGuard;
152+
153+ moveOutOfRegion (op, moveWithoutGuard ? loopRegion->getParentRegion ()
154+ : loopRegion);
93155 ++numMoved;
94156
95157 // Since the op has been moved, we need to check its users within the
96158 // top-level of the loop body.
97159 for (Operation *user : op->getUsers ())
98- if (user->getParentRegion () == region )
160+ if (user->getParentRegion () == loopRegion )
99161 worklist.push (user);
100162 }
163+
164+ // Unwrap the loop if it was wrapped but no ops were moved in the guard.
165+ if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
166+ auto tripCountCheckUnwrapped = unwrapGuard ();
167+ if (failed (tripCountCheckUnwrapped))
168+ llvm_unreachable (" Should not fail unwrapping trip-count check" );
169+ }
101170 }
102171
103172 return numMoved;
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
106175size_t mlir::moveLoopInvariantCode (LoopLikeOpInterface loopLike) {
107176 return moveLoopInvariantCode (
108177 loopLike.getLoopRegions (),
109- [&](Value value, Region *) {
110- return loopLike. isDefinedOutsideOfLoop (value);
178+ [&](Value value, Region *region ) {
179+ return !region-> isAncestor (value. getParentRegion () );
111180 },
112181 [&](Operation *op, Region *) {
113- return isMemoryEffectFree (op) && isSpeculatable (op);
182+ return isSpeculatable (op) &&
183+ (isMemoryEffectFree (op) || hasOnlyReadEffect (op));
184+ },
185+ [&]() { return loopLike.wrapInTripCountCheck (); },
186+ [&](Operation *op, Region *region) {
187+ op->moveBefore (region->getParentOp ());
114188 },
115- [&](Operation *op, Region * ) { loopLike.moveOutOfLoop (op ); });
189+ [&]() { return loopLike.unwrapTripCountCheck ( ); });
116190}
117191
118192namespace {
0 commit comments