Skip to content

Commit e665f24

Browse files
[mlir] Delete unroll-full option for Affine/SCF unroll pass (llvm#164658)
Make the unroll-factor take -1 as "full" and avoid potential conflict when passing both an explicit factor and unroll-full=true.
1 parent b08bbe5 commit e665f24

File tree

6 files changed

+28
-26
lines changed

6 files changed

+28
-26
lines changed

mlir/include/mlir/Dialect/Affine/Passes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLoopTilingPass();
106106
/// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor).
107107
std::unique_ptr<InterfacePass<FunctionOpInterface>> createLoopUnrollPass(
108108
int unrollFactor = -1, bool unrollUpToFactor = false,
109-
bool unrollFull = false,
110109
const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr);
111110

112111
/// Creates a loop unroll jam pass to unroll jam by the specified factor. A

mlir/include/mlir/Dialect/Affine/Passes.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,10 @@ def AffineLoopUnroll : InterfacePass<"affine-loop-unroll", "FunctionOpInterface"
203203
let summary = "Unroll affine loops";
204204
let constructor = "mlir::affine::createLoopUnrollPass()";
205205
let options = [
206-
Option<"unrollFactor", "unroll-factor", "unsigned", /*default=*/"4",
206+
Option<"unrollFactor", "unroll-factor", "int64_t", /*default=*/"4",
207207
"Use this unroll factor for all loops being unrolled">,
208208
Option<"unrollUpToFactor", "unroll-up-to-factor", "bool",
209209
/*default=*/"false", "Allow unrolling up to the factor specified">,
210-
Option<"unrollFull", "unroll-full", "bool", /*default=*/"false",
211-
"Fully unroll loops">,
212210
Option<"numRepetitions", "unroll-num-reps", "unsigned", /*default=*/"1",
213211
"Unroll innermost loops repeatedly this many times">,
214212
Option<"unrollFullThreshold", "unroll-full-threshold", "unsigned",

mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,15 @@ struct LoopUnroll : public affine::impl::AffineLoopUnrollBase<LoopUnroll> {
4545
const std::function<unsigned(AffineForOp)> getUnrollFactor;
4646

4747
LoopUnroll() : getUnrollFactor(nullptr) {}
48-
LoopUnroll(const LoopUnroll &other)
49-
50-
= default;
48+
LoopUnroll(const LoopUnroll &other) = default;
5149
explicit LoopUnroll(
5250
std::optional<unsigned> unrollFactor = std::nullopt,
53-
bool unrollUpToFactor = false, bool unrollFull = false,
51+
bool unrollUpToFactor = false,
5452
const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
5553
: getUnrollFactor(getUnrollFactor) {
5654
if (unrollFactor)
5755
this->unrollFactor = *unrollFactor;
5856
this->unrollUpToFactor = unrollUpToFactor;
59-
this->unrollFull = unrollFull;
6057
}
6158

6259
void runOnOperation() override;
@@ -85,11 +82,17 @@ static void gatherInnermostLoops(FunctionOpInterface f,
8582
}
8683

8784
void LoopUnroll::runOnOperation() {
85+
if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) {
86+
emitError(UnknownLoc::get(&getContext()),
87+
"Invalid option: 'unroll-factor' should be greater than 0 or "
88+
"equal to -1");
89+
return signalPassFailure();
90+
}
8891
FunctionOpInterface func = getOperation();
8992
if (func.isExternal())
9093
return;
9194

92-
if (unrollFull && unrollFullThreshold.hasValue()) {
95+
if (unrollFactor.getValue() == -1 && unrollFullThreshold.hasValue()) {
9396
// Store short loops as we walk.
9497
SmallVector<AffineForOp, 4> loops;
9598

@@ -130,7 +133,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
130133
return loopUnrollByFactor(forOp, getUnrollFactor(forOp),
131134
/*annotateFn=*/nullptr, cleanUpUnroll);
132135
// Unroll completely if full loop unroll was specified.
133-
if (unrollFull)
136+
if (unrollFactor.getValue() == -1)
134137
return loopUnrollFull(forOp);
135138
// Otherwise, unroll by the given unroll factor.
136139
if (unrollUpToFactor)
@@ -141,9 +144,9 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
141144

142145
std::unique_ptr<InterfacePass<FunctionOpInterface>>
143146
mlir::affine::createLoopUnrollPass(
144-
int unrollFactor, bool unrollUpToFactor, bool unrollFull,
147+
int unrollFactor, bool unrollUpToFactor,
145148
const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
146149
return std::make_unique<LoopUnroll>(
147150
unrollFactor == -1 ? std::nullopt : std::optional<unsigned>(unrollFactor),
148-
unrollUpToFactor, unrollFull, getUnrollFactor);
151+
unrollUpToFactor, getUnrollFactor);
149152
}

mlir/test/Dialect/Affine/unroll.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-full=true}))" | FileCheck %s --check-prefix UNROLL-FULL
2-
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-full=true unroll-full-threshold=2}))" | FileCheck %s --check-prefix SHORT
1+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=-1}))" | FileCheck %s --check-prefix UNROLL-FULL
2+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=-1 unroll-full-threshold=2}))" | FileCheck %s --check-prefix SHORT
33
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=4}))" | FileCheck %s --check-prefix UNROLL-BY-4
44
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=1}))" | FileCheck %s --check-prefix UNROLL-BY-1
55
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=5 cleanup-unroll=true}))" | FileCheck %s --check-prefix UNROLL-CLEANUP-LOOP
6-
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(gpu.module(gpu.func(affine-loop-unroll{unroll-full=true})))" | FileCheck %s --check-prefix GPU-UNROLL-FULL
6+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(gpu.module(gpu.func(affine-loop-unroll{unroll-factor=-1})))" | FileCheck %s --check-prefix GPU-UNROLL-FULL
77

88
// UNROLL-FULL-DAG: [[$MAP0:#map[0-9]*]] = affine_map<(d0) -> (d0 + 1)>
99
// UNROLL-FULL-DAG: [[$MAP1:#map[0-9]*]] = affine_map<(d0) -> (d0 + 2)>

mlir/test/Transforms/scf-loop-unroll.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
22
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
3-
// RUN: mlir-opt %s --test-loop-unrolling="unroll-full=true" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL
3+
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=-1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL
44

55
// CHECK-LABEL: scf_loop_unroll_single
66
func.func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {

mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,23 @@ struct TestLoopUnrollingPass
4242
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
4343
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
4444
unsigned loopDepthParam,
45-
bool annotateLoopParam, bool unrollFullParam) {
45+
bool annotateLoopParam) {
4646
unrollFactor = unrollFactorParam;
4747
loopDepth = loopDepthParam;
4848
annotateLoop = annotateLoopParam;
49-
unrollFull = unrollFactorParam;
5049
}
5150

5251
void getDependentDialects(DialectRegistry &registry) const override {
5352
registry.insert<arith::ArithDialect>();
5453
}
5554

5655
void runOnOperation() override {
56+
if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) {
57+
emitError(UnknownLoc::get(&getContext()),
58+
"Invalid option: 'unroll-factor' should be greater than 0 or "
59+
"equal to -1");
60+
return signalPassFailure();
61+
}
5762
SmallVector<scf::ForOp, 4> loops;
5863
getOperation()->walk([&](scf::ForOp forOp) {
5964
if (getNestingDepth(forOp) == loopDepth)
@@ -65,15 +70,15 @@ struct TestLoopUnrollingPass
6570
}
6671
};
6772
for (auto loop : loops) {
68-
if (unrollFull)
73+
if (unrollFactor.getValue() == -1)
6974
(void)loopUnrollFull(loop);
7075
else
7176
(void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
7277
}
7378
}
74-
Option<uint64_t> unrollFactor{*this, "unroll-factor",
75-
llvm::cl::desc("Loop unroll factor."),
76-
llvm::cl::init(1)};
79+
Option<int64_t> unrollFactor{*this, "unroll-factor",
80+
llvm::cl::desc("Loop unroll factor."),
81+
llvm::cl::init(1)};
7782
Option<bool> annotateLoop{*this, "annotate",
7883
llvm::cl::desc("Annotate unrolled iterations."),
7984
llvm::cl::init(false)};
@@ -82,9 +87,6 @@ struct TestLoopUnrollingPass
8287
llvm::cl::init(false)};
8388
Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
8489
llvm::cl::init(0)};
85-
Option<bool> unrollFull{*this, "unroll-full",
86-
llvm::cl::desc("Full unroll loops."),
87-
llvm::cl::init(false)};
8890
};
8991
} // namespace
9092

0 commit comments

Comments
 (0)