Skip to content

Commit ec46002

Browse files
committed
make return value more structured.
1 parent 6e779e6 commit ec46002

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,18 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
111111
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
112112
ArrayRef<std::vector<unsigned>> combinedDimensions);
113113

114+
struct UnrolledLoopInfo {
115+
scf::ForOp mainLoopOp;
116+
scf::ForOp epilogueLoopOp;
117+
};
118+
114119
/// Unrolls this for operation by the specified unroll factor. Returns the
115-
/// unrolled main loop and the eplilog loop in sequence, if the loop is
116-
/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
117-
/// either due to restrictions or due to invalid unroll factors. Requires
118-
/// positive loop bounds and step. If specified, annotates the Ops in each
119-
/// unrolled iteration by applying `annotateFn`.
120-
SmallVector<scf::ForOp> loopUnrollByFactor(
120+
/// unrolled main loop and the eplilog loop, if the loop is unrolled. Otherwise
121+
/// returns a strucutre of null fields if the loop cannot be unrolled either due
122+
/// to restrictions or due to invalid unroll factors. Requires positive loop
123+
/// bounds and step. If specified, annotates the Ops in each unrolled iteration
124+
/// by applying `annotateFn`.
125+
UnrolledLoopInfo loopUnrollByFactor(
121126
scf::ForOp forOp, uint64_t unrollFactor,
122127
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
123128

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,8 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
355355
LogicalResult result(failure());
356356
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
357357
auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
358-
result = resultLoops.empty() ? failure() : success();
359-
}
360-
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
358+
result = resultLoops.mainLoopOp ? success() : failure();
359+
} else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
361360
result = loopUnrollByFactor(affineFor, getFactor());
362361
else
363362
return emitSilenceableError()

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -373,16 +373,15 @@ static void generateUnrolledLoop(
373373
}
374374

375375
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
376-
/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
377-
/// vector.
378-
SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
376+
/// eplilog loop, if the loop is unrolled. Otherwise return null.
377+
UnrolledLoopInfo mlir::loopUnrollByFactor(
379378
scf::ForOp forOp, uint64_t unrollFactor,
380379
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
381380
assert(unrollFactor > 0 && "expected positive unroll factor");
382381

383382
// Return if the loop body is empty.
384383
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
385-
return {forOp};
384+
return {forOp, nullptr};
386385

387386
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
388387
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -403,8 +402,8 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
403402
if (unrollFactor == 1) {
404403
if (*constTripCount == 1 &&
405404
failed(forOp.promoteIfSingleIteration(rewriter)))
406-
return {};
407-
return {forOp};
405+
return {nullptr, nullptr};
406+
return {forOp, nullptr};
408407
}
409408

410409
int64_t tripCountEvenMultiple =
@@ -452,8 +451,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
452451
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
453452
}
454453

455-
SmallVector<scf::ForOp, 2> resultLoops;
456-
resultLoops.push_back(forOp);
454+
UnrolledLoopInfo resultLoops;
457455

458456
// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
459457
if (generateEpilogueLoop) {
@@ -473,7 +471,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
473471
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
474472
epilogueForOp.getInitArgs().size(), results);
475473
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
476-
resultLoops.push_back(epilogueForOp);
474+
resultLoops.epilogueLoopOp = epilogueForOp;
477475
}
478476

479477
// Create unrolled loop.
@@ -496,6 +494,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
496494
annotateFn, iterArgs, yieldedValues);
497495
// Promote the loop body up if this has turned into a single iteration loop.
498496
(void)forOp.promoteIfSingleIteration(rewriter);
497+
resultLoops.mainLoopOp = forOp;
499498
return resultLoops;
500499
}
501500

0 commit comments

Comments
 (0)