Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);

/// Unrolls this loop completely.
LogicalResult loopUnrollFull(scf::ForOp forOp);

/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
/// Returns failure if the loop cannot be unrolled either due to restrictions or
/// due to invalid unroll factors. In case of unroll factor of 1, the function
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,21 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
return resultLoops;
}

/// Unrolls this loop completely.
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
IRRewriter rewriter(forOp.getContext());
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (mayBeConstantTripCount.has_value()) {
uint64_t tripCount = *mayBeConstantTripCount;
if (tripCount == 0)
return success();
if (tripCount == 1)
return forOp.promoteIfSingleIteration(rewriter);
return loopUnrollByFactor(forOp, tripCount);
}
return failure();
}

/// Check if bounds of all inner loops are defined outside of `forOp`
/// and return false if not.
static bool areInnerBoundsInvariant(scf::ForOp forOp) {
Expand Down
57 changes: 57 additions & 0 deletions mlir/test/Transforms/scf-loop-unroll.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
// RUN: mlir-opt %s --test-loop-unrolling="unroll-full=true" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL

// CHECK-LABEL: scf_loop_unroll_single
func.func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
Expand Down Expand Up @@ -56,3 +57,59 @@ func.func @scf_loop_unroll_factor_1_promote() -> () {
// UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index
// UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
}

// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_single(
// UNROLL-FULL-SAME: %[[VAL_0:.*]]: index) -> index {
func.func @scf_loop_unroll_full_single(%arg : index) -> index {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
%2 = arith.constant 4 : index
%4 = scf.for %iv = %0 to %2 step %1 iter_args(%arg1 = %1) -> index {
%3 = arith.addi %arg1, %arg : index
scf.yield %3 : index
}
return %4 : index
// UNROLL-FULL: %[[VAL_1:.*]] = arith.constant 1 : index
// UNROLL-FULL: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : index
// UNROLL-FULL: %[[VAL_3:.*]] = arith.addi %[[VAL_2]], %[[VAL_0]] : index
// UNROLL-FULL: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_0]] : index
// UNROLL-FULL: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
// UNROLL-FULL: return %[[VAL_5]] : index
}

// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_outter_loops(
// UNROLL-FULL-SAME: %[[VAL_0:.*]]: vector<4x4xindex>) -> index {
func.func @scf_loop_unroll_full_outter_loops(%arg0: vector<4x4xindex>) -> index {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
%2 = arith.constant 4 : index
%6 = scf.for %arg1 = %0 to %2 step %1 iter_args(%it0 = %0) -> index {
%5 = scf.for %arg2 = %0 to %2 step %1 iter_args(%it1 = %it0) -> index {
%3 = vector.extract %arg0[%arg1, %arg2] : index from vector<4x4xindex>
%4 = arith.addi %3, %it1 : index
scf.yield %3 : index
}
scf.yield %5 : index
}
return %6 : index
// UNROLL-FULL: %[[VAL_1:.*]] = arith.constant 0 : index
// UNROLL-FULL: %[[VAL_2:.*]] = arith.constant 1 : index
// UNROLL-FULL: %[[VAL_3:.*]] = arith.constant 4 : index
// UNROLL-FULL: %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_6:.*]] = %[[VAL_1]]) -> (index) {
// UNROLL-FULL: %[[VAL_7:.*]] = vector.extract %[[VAL_0]][0, %[[VAL_5]]] : index from vector<4x4xindex>
// UNROLL-FULL: scf.yield %[[VAL_7]] : index
// UNROLL-FULL: }
// UNROLL-FULL: %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_10:.*]] = %[[VAL_4]]) -> (index) {
// UNROLL-FULL: %[[VAL_11:.*]] = vector.extract %[[VAL_0]][1, %[[VAL_9]]] : index from vector<4x4xindex>
// UNROLL-FULL: scf.yield %[[VAL_11]] : index
// UNROLL-FULL: }
// UNROLL-FULL: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_8]]) -> (index) {
// UNROLL-FULL: %[[VAL_15:.*]] = vector.extract %[[VAL_0]][2, %[[VAL_13]]] : index from vector<4x4xindex>
// UNROLL-FULL: scf.yield %[[VAL_15]] : index
// UNROLL-FULL: }
// UNROLL-FULL: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_18:.*]] = %[[VAL_12]]) -> (index) {
// UNROLL-FULL: %[[VAL_19:.*]] = vector.extract %[[VAL_0]][3, %[[VAL_17]]] : index from vector<4x4xindex>
// UNROLL-FULL: scf.yield %[[VAL_19]] : index
// UNROLL-FULL: }
// UNROLL-FULL: return %[[VAL_16]] : index
}
14 changes: 11 additions & 3 deletions mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ struct TestLoopUnrollingPass
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
unsigned loopDepthParam,
bool annotateLoopParam) {
bool annotateLoopParam, bool unrollFullParam) {
unrollFactor = unrollFactorParam;
loopDepth = loopDepthParam;
annotateLoop = annotateLoopParam;
unrollFull = unrollFactorParam;
}

void getDependentDialects(DialectRegistry &registry) const override {
Expand All @@ -63,8 +64,12 @@ struct TestLoopUnrollingPass
op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
}
};
for (auto loop : loops)
(void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
for (auto loop : loops) {
if (unrollFull)
loopUnrollFull(loop);
else
(void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
}
}
Option<uint64_t> unrollFactor{*this, "unroll-factor",
llvm::cl::desc("Loop unroll factor."),
Expand All @@ -77,6 +82,9 @@ struct TestLoopUnrollingPass
llvm::cl::init(false)};
Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
llvm::cl::init(0)};
Option<bool> unrollFull{*this, "unroll-full",
llvm::cl::desc("Full unroll loops."),
llvm::cl::init(false)};
};
} // namespace

Expand Down