Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,13 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);

/// Unrolls this for operation by the specified unroll factor. Returns failure
/// if the loop cannot be unrolled either due to restrictions or due to invalid
/// unroll factors. Requires positive loop bounds and step. If specified,
/// annotates the Ops in each unrolled iteration by applying `annotateFn`.
LogicalResult loopUnrollByFactor(
/// Unrolls this for operation by the specified unroll factor. Returns the
/// unrolled main loop and the eplilog loop in sequence, if the loop is
/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
/// either due to restrictions or due to invalid unroll factors. Requires
/// positive loop bounds and step. If specified, annotates the Ops in each
/// unrolled iteration by applying `annotateFn`.
SmallVector<scf::ForOp> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);

Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,10 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LogicalResult result(failure());
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
result = loopUnrollByFactor(scfFor, getFactor());
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
result = resultLoops.empty() ? failure() : success();
}
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
else
Expand Down
18 changes: 12 additions & 6 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,15 +372,17 @@ static void generateUnrolledLoop(
loopBodyBlock->getTerminator()->setOperands(lastYielded);
}

/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
LogicalResult mlir::loopUnrollByFactor(
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
/// vector.
SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");

// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
return success();
return {forOp};

// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
Expand All @@ -401,8 +403,8 @@ LogicalResult mlir::loopUnrollByFactor(
if (unrollFactor == 1) {
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return success();
return {};
return {forOp};
}

int64_t tripCountEvenMultiple =
Expand Down Expand Up @@ -450,6 +452,9 @@ LogicalResult mlir::loopUnrollByFactor(
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
}

SmallVector<scf::ForOp, 2> resultLoops;
resultLoops.push_back(forOp);

// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
if (generateEpilogueLoop) {
OpBuilder epilogueBuilder(forOp->getContext());
Expand All @@ -468,6 +473,7 @@ LogicalResult mlir::loopUnrollByFactor(
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
resultLoops.push_back(epilogueForOp);
}

// Create unrolled loop.
Expand All @@ -490,7 +496,7 @@ LogicalResult mlir::loopUnrollByFactor(
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)forOp.promoteIfSingleIteration(rewriter);
return success();
return resultLoops;
}

/// Check if bounds of all inner loops are defined outside of `forOp`
Expand Down
Loading