Skip to content

Commit a6abc7a

Browse files
committed
Changed return type to FailureOr<UnrolledLoopInfo>
1 parent ec46002 commit a6abc7a

File tree

3 files changed

+19
-20
lines changed

3 files changed

+19
-20
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,17 +112,17 @@ void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
112112
ArrayRef<std::vector<unsigned>> combinedDimensions);
113113

114114
struct UnrolledLoopInfo {
115-
scf::ForOp mainLoopOp;
116-
scf::ForOp epilogueLoopOp;
115+
scf::ForOp mainLoopOp = nullptr;
116+
scf::ForOp epilogueLoopOp = nullptr;
117117
};
118118

119119
/// Unrolls this for operation by the specified unroll factor. Returns the
120120
/// unrolled main loop and the eplilog loop, if the loop is unrolled. Otherwise
121-
/// returns a strucutre of null fields if the loop cannot be unrolled either due
122-
/// to restrictions or due to invalid unroll factors. Requires positive loop
123-
/// bounds and step. If specified, annotates the Ops in each unrolled iteration
124-
/// by applying `annotateFn`.
125-
UnrolledLoopInfo loopUnrollByFactor(
121+
/// returns failure if the loop cannot be unrolled either due to restrictions or
122+
/// due to invalid unroll factors. Requires positive loop bounds and step. If
123+
/// specified, annotates the Ops in each unrolled iteration by applying
124+
/// `annotateFn`.
125+
FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
126126
scf::ForOp forOp, uint64_t unrollFactor,
127127
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
128128

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,10 +353,9 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
353353
transform::ApplyToEachResultList &results,
354354
transform::TransformState &state) {
355355
LogicalResult result(failure());
356-
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
357-
auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
358-
result = resultLoops.mainLoopOp ? success() : failure();
359-
} else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
356+
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
357+
result = loopUnrollByFactor(scfFor, getFactor());
358+
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
360359
result = loopUnrollByFactor(affineFor, getFactor());
361360
else
362361
return emitSilenceableError()

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -373,15 +373,15 @@ static void generateUnrolledLoop(
373373
}
374374

375375
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
376-
/// eplilog loop, if the loop is unrolled. Otherwise return null.
377-
UnrolledLoopInfo mlir::loopUnrollByFactor(
376+
/// eplilog loop, if the loop is unrolled.
377+
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
378378
scf::ForOp forOp, uint64_t unrollFactor,
379379
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
380380
assert(unrollFactor > 0 && "expected positive unroll factor");
381381

382382
// Return if the loop body is empty.
383383
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
384-
return {forOp, nullptr};
384+
return UnrolledLoopInfo{forOp, nullptr};
385385

386386
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
387387
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -402,8 +402,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
402402
if (unrollFactor == 1) {
403403
if (*constTripCount == 1 &&
404404
failed(forOp.promoteIfSingleIteration(rewriter)))
405-
return {nullptr, nullptr};
406-
return {forOp, nullptr};
405+
return failure();
406+
return UnrolledLoopInfo{forOp, nullptr};
407407
}
408408

409409
int64_t tripCountEvenMultiple =
@@ -470,8 +470,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
470470
}
471471
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
472472
epilogueForOp.getInitArgs().size(), results);
473-
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
474-
resultLoops.epilogueLoopOp = epilogueForOp;
473+
if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
474+
resultLoops.epilogueLoopOp = epilogueForOp;
475475
}
476476

477477
// Create unrolled loop.
@@ -493,8 +493,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
493493
},
494494
annotateFn, iterArgs, yieldedValues);
495495
// Promote the loop body up if this has turned into a single iteration loop.
496-
(void)forOp.promoteIfSingleIteration(rewriter);
497-
resultLoops.mainLoopOp = forOp;
496+
if (forOp.promoteIfSingleIteration(rewriter).failed())
497+
resultLoops.mainLoopOp = forOp;
498498
return resultLoops;
499499
}
500500

0 commit comments

Comments
 (0)