-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] Enable LICM for ops with only read side effects in scf.for #120302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
0e596fd
c532810
8a38077
190f03a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -395,6 +395,60 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() { | |
|
|
||
| std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); } | ||
|
|
||
| FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() { | ||
|
|
||
| IRRewriter rewriter(this->getContext()); | ||
| OpBuilder::InsertionGuard insertGuard(rewriter); | ||
| rewriter.setInsertionPointAfter(this->getOperation()); | ||
|
|
||
| auto loc = this->getLoc(); | ||
| auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, | ||
| this->getUpperBound(), | ||
| this->getLowerBound()); | ||
| scf::YieldOp yieldInThen; | ||
| // Create the trip-count check. | ||
| auto ifOp = rewriter.create<scf::IfOp>( | ||
| loc, cmpIOp, | ||
| [&](OpBuilder &builder, Location loc) { | ||
| yieldInThen = builder.create<scf::YieldOp>(loc, this->getResults()); | ||
| }, | ||
| [&](OpBuilder &builder, Location loc) { | ||
| builder.create<scf::YieldOp>(loc, this->getInitArgs()); | ||
| }); | ||
|
|
||
| for (auto [forOpResult, ifOpResult] : | ||
| llvm::zip(this->getResults(), ifOp.getResults())) | ||
| rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen); | ||
| // Move the scf.for into the then block. | ||
| rewriter.moveOpBefore(this->getOperation(), yieldInThen); | ||
| return std::make_pair(ifOp.getOperation(), &this->getRegion()); | ||
| } | ||
|
|
||
| LogicalResult ForOp::unwrapTripCountCheck() { | ||
ardaunal marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| auto ifOp = (*this)->getParentRegion()->getParentOp(); | ||
| if (!isa<scf::IfOp>(ifOp)) | ||
| return failure(); | ||
|
|
||
| IRRewriter rewriter(ifOp->getContext()); | ||
| OpBuilder::InsertionGuard insertGuard(rewriter); | ||
| rewriter.setInsertionPoint(ifOp); | ||
|
||
|
|
||
| auto cmpOp = ifOp->getOperand(0).getDefiningOp(); | ||
| if (!isa<arith::CmpIOp>(cmpOp)) | ||
| return failure(); | ||
|
|
||
| auto wrappedForOp = this->getOperation(); | ||
| rewriter.moveOpBefore(wrappedForOp, ifOp); | ||
ardaunal marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| for (auto [forOpResult, ifOpResult] : | ||
| llvm::zip(wrappedForOp->getResults(), ifOp->getResults())) | ||
| rewriter.replaceAllUsesWith(ifOpResult, forOpResult); | ||
|
|
||
| rewriter.eraseOp(ifOp); | ||
| rewriter.eraseOp(cmpOp); | ||
| return success(); | ||
| } | ||
|
|
||
| /// Promotes the loop body of a forOp to its containing block if the forOp | ||
| /// it can be determined that the loop has a single iteration. | ||
| LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) { | ||
|
|
@@ -3397,9 +3451,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) { | |
|
|
||
| if (functionType.getNumInputs() != operands.size()) { | ||
| return parser.emitError(typeLoc) | ||
| << "expected as many input types as operands " | ||
| << "(expected " << operands.size() << " got " | ||
| << functionType.getNumInputs() << ")"; | ||
| << "expected as many input types as operands " << "(expected " | ||
| << operands.size() << " got " << functionType.getNumInputs() << ")"; | ||
| } | ||
|
|
||
| // Resolve input operands. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -306,6 +306,26 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) { | |
| return wouldOpBeTriviallyDeadImpl(op); | ||
| } | ||
|
|
||
| bool mlir::hasOnlyReadEffect(Operation *op) { | ||
| if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) { | ||
| if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) | ||
| return memEffects.onlyHasEffect<MemoryEffects::Read>(); | ||
| } else if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) { | ||
| // Otherwise, if the op does not implement the memory effect interface and | ||
| // it does not have recursive side effects, then it cannot be known that the | ||
| // op is moveable. | ||
| return false; | ||
|
||
| } | ||
|
|
||
| // Recurse into the regions and ensure that all nested ops are memory effect | ||
| // free. | ||
| for (Region ®ion : op->getRegions()) | ||
| for (Operation &op : region.getOps()) | ||
| if (!hasOnlyReadEffect(&op)) | ||
| return false; | ||
| return true; | ||
| } | ||
|
|
||
| bool mlir::isMemoryEffectFree(Operation *op) { | ||
| if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) { | ||
| if (!memInterface.hasNoEffect()) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op, | |
| op, [&](OpOperand &operand) { return definedOutside(operand.get()); }); | ||
| } | ||
|
|
||
| static bool dependsOnGuarded(Operation *op, | ||
| function_ref<bool(OpOperand &)> condition) { | ||
| auto walkFn = [&](Operation *child) { | ||
| for (OpOperand &operand : child->getOpOperands()) { | ||
| if (!condition(operand)) | ||
| return WalkResult::interrupt(); | ||
| } | ||
| return WalkResult::advance(); | ||
| }; | ||
| return op->walk(walkFn).wasInterrupted(); | ||
| } | ||
|
|
||
| static bool dependsOnGuarded(Operation *op, | ||
| function_ref<bool(Value)> definedOutsideGuard) { | ||
| return dependsOnGuarded(op, [&](OpOperand &operand) { | ||
| return definedOutsideGuard(operand.get()); | ||
| }); | ||
| } | ||
|
|
||
| static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) { | ||
| for (Region ®ion : loop->getRegions()) { | ||
| for (Block &block : region.getBlocks()) { | ||
| for (Operation &op : block.getOperations()) { | ||
| if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op)) | ||
| return false; | ||
cxy-1993 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| size_t mlir::moveLoopInvariantCode( | ||
| ArrayRef<Region *> regions, | ||
| function_ref<bool(Value, Region *)> isDefinedOutsideRegion, | ||
| function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion, | ||
| function_ref<void(Operation *, Region *)> moveOutOfRegion) { | ||
| function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard, | ||
| function_ref<void(Operation *, Region *)> moveOutOfRegion, | ||
| function_ref<LogicalResult()> unwrapGuard) { | ||
| size_t numMoved = 0; | ||
|
|
||
| for (Region *region : regions) { | ||
| LLVM_DEBUG(llvm::dbgs() << "Original loop:\n" | ||
| << *region->getParentOp() << "\n"); | ||
|
|
||
| auto loopSideEffectFreeOrHasOnlyReadSideEffect = | ||
| loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp()); | ||
|
|
||
| size_t numMovedWithoutGuard = 0; | ||
|
|
||
| FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard(); | ||
| Region *loopRegion = region; | ||
| auto isLoopWrapped = false; | ||
| if (succeeded(ifOpAndRegion)) { | ||
| loopRegion = ifOpAndRegion->second; | ||
| isLoopWrapped = true; | ||
| } | ||
|
|
||
| std::queue<Operation *> worklist; | ||
| // Add top-level operations in the loop body to the worklist. | ||
| for (Operation &op : region->getOps()) | ||
| for (Operation &op : loopRegion->getOps()) | ||
| worklist.push(&op); | ||
|
|
||
| auto definedOutside = [&](Value value) { | ||
| return isDefinedOutsideRegion(value, region); | ||
| return isDefinedOutsideRegion(value, loopRegion); | ||
| }; | ||
|
|
||
| auto definedOutsideGuard = [&](Value value) { | ||
| return isDefinedOutsideRegion(value, loopRegion->getParentRegion()); | ||
| }; | ||
|
|
||
| while (!worklist.empty()) { | ||
| Operation *op = worklist.front(); | ||
| worklist.pop(); | ||
| // Skip ops that have already been moved. Check if the op can be hoisted. | ||
| if (op->getParentRegion() != region) | ||
| if (op->getParentRegion() != loopRegion) | ||
| continue; | ||
|
|
||
| LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n"); | ||
| if (!shouldMoveOutOfRegion(op, region) || | ||
|
|
||
| if (!shouldMoveOutOfRegion(op, loopRegion) || | ||
| !canBeHoisted(op, definedOutside)) | ||
| continue; | ||
| // Can only hoist pure ops (side-effect free) when there is an op with | ||
| // write and/or unknown side effects in the loop. | ||
| if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op)) | ||
| continue; | ||
|
|
||
| LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n"); | ||
| moveOutOfRegion(op, region); | ||
|
|
||
| auto moveWithoutGuard = isMemoryEffectFree(op) && | ||
|
||
| !dependsOnGuarded(op, definedOutsideGuard) && | ||
| isLoopWrapped; | ||
| numMovedWithoutGuard += moveWithoutGuard; | ||
|
|
||
| moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion() | ||
| : loopRegion); | ||
| ++numMoved; | ||
|
|
||
| // Since the op has been moved, we need to check its users within the | ||
| // top-level of the loop body. | ||
| for (Operation *user : op->getUsers()) | ||
| if (user->getParentRegion() == region) | ||
| if (user->getParentRegion() == loopRegion) | ||
| worklist.push(user); | ||
| } | ||
|
|
||
| // Unwrap the loop if it was wrapped but no ops were moved in the guard. | ||
| if (isLoopWrapped && numMovedWithoutGuard == numMoved) { | ||
| auto tripCountCheckUnwrapped = unwrapGuard(); | ||
| if (failed(tripCountCheckUnwrapped)) | ||
| llvm_unreachable("Should not fail unwrapping trip-count check"); | ||
| } | ||
|
||
| } | ||
|
|
||
| return numMoved; | ||
|
|
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode( | |
| size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { | ||
| return moveLoopInvariantCode( | ||
| loopLike.getLoopRegions(), | ||
| [&](Value value, Region *) { | ||
| return loopLike.isDefinedOutsideOfLoop(value); | ||
| [&](Value value, Region *region) { | ||
| return !region->isAncestor(value.getParentRegion()); | ||
| }, | ||
| [&](Operation *op, Region *) { | ||
| return isMemoryEffectFree(op) && isSpeculatable(op); | ||
| return isSpeculatable(op) && | ||
| (isMemoryEffectFree(op) || hasOnlyReadEffect(op)); | ||
ardaunal marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| }, | ||
| [&]() { return loopLike.wrapInTripCountCheck(); }, | ||
| [&](Operation *op, Region *region) { | ||
| op->moveBefore(region->getParentOp()); | ||
| }, | ||
| [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); | ||
| [&]() { return loopLike.unwrapTripCountCheck(); }); | ||
| } | ||
|
|
||
| namespace { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Strip blank lines