Skip to content

Commit 17c0715

Browse files
update impl and test.
1 parent 2b68fd1 commit 17c0715

File tree

3 files changed

+35
-36
lines changed

3 files changed

+35
-36
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -502,15 +502,14 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
502502
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
503503
IRRewriter rewriter(forOp.getContext());
504504
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
505-
if (mayBeConstantTripCount.has_value()) {
506-
uint64_t tripCount = *mayBeConstantTripCount;
507-
if (tripCount == 0)
508-
return success();
509-
if (tripCount == 1)
510-
return forOp.promoteIfSingleIteration(rewriter);
511-
return loopUnrollByFactor(forOp, tripCount);
512-
}
513-
return failure();
505+
if (!mayBeConstantTripCount.has_value())
506+
return failure();
507+
uint64_t tripCount = *mayBeConstantTripCount;
508+
if (tripCount == 0)
509+
return success();
510+
if (tripCount == 1)
511+
return forOp.promoteIfSingleIteration(rewriter);
512+
return loopUnrollByFactor(forOp, tripCount);
514513
}
515514

516515
/// Check if bounds of all inner loops are defined outside of `forOp`

mlir/test/Transforms/scf-loop-unroll.mlir

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ func.func @scf_loop_unroll_factor_1_promote() -> () {
5858
// UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
5959
}
6060

61-
// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_single(
62-
// UNROLL-FULL-SAME: %[[VAL_0:.*]]: index) -> index {
61+
// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_single
62+
// UNROLL-FULL-SAME: %[[ARG:.*]]: index)
6363
func.func @scf_loop_unroll_full_single(%arg : index) -> index {
6464
%0 = arith.constant 0 : index
6565
%1 = arith.constant 1 : index
@@ -69,16 +69,16 @@ func.func @scf_loop_unroll_full_single(%arg : index) -> index {
6969
scf.yield %3 : index
7070
}
7171
return %4 : index
72-
// UNROLL-FULL: %[[VAL_1:.*]] = arith.constant 1 : index
73-
// UNROLL-FULL: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : index
74-
// UNROLL-FULL: %[[VAL_3:.*]] = arith.addi %[[VAL_2]], %[[VAL_0]] : index
75-
// UNROLL-FULL: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_0]] : index
76-
// UNROLL-FULL: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
77-
// UNROLL-FULL: return %[[VAL_5]] : index
72+
// UNROLL-FULL: %[[C1:.*]] = arith.constant 1 : index
73+
// UNROLL-FULL: %[[V0:.*]] = arith.addi %[[ARG]], %[[C1]] : index
74+
// UNROLL-FULL: %[[V1:.*]] = arith.addi %[[V0]], %[[ARG]] : index
75+
// UNROLL-FULL: %[[V2:.*]] = arith.addi %[[V1]], %[[ARG]] : index
76+
// UNROLL-FULL: %[[V3:.*]] = arith.addi %[[V2]], %[[ARG]] : index
77+
// UNROLL-FULL: return %[[V3]] : index
7878
}
7979

80-
// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_outter_loops(
81-
// UNROLL-FULL-SAME: %[[VAL_0:.*]]: vector<4x4xindex>) -> index {
80+
// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_outter_loops
81+
// UNROLL-FULL-SAME: %[[ARG:.*]]: vector<4x4xindex>)
8282
func.func @scf_loop_unroll_full_outter_loops(%arg0: vector<4x4xindex>) -> index {
8383
%0 = arith.constant 0 : index
8484
%1 = arith.constant 1 : index
@@ -92,24 +92,24 @@ func.func @scf_loop_unroll_full_outter_loops(%arg0: vector<4x4xindex>) -> index
9292
scf.yield %5 : index
9393
}
9494
return %6 : index
95-
// UNROLL-FULL: %[[VAL_1:.*]] = arith.constant 0 : index
96-
// UNROLL-FULL: %[[VAL_2:.*]] = arith.constant 1 : index
97-
// UNROLL-FULL: %[[VAL_3:.*]] = arith.constant 4 : index
98-
// UNROLL-FULL: %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_6:.*]] = %[[VAL_1]]) -> (index) {
99-
// UNROLL-FULL: %[[VAL_7:.*]] = vector.extract %[[VAL_0]][0, %[[VAL_5]]] : index from vector<4x4xindex>
100-
// UNROLL-FULL: scf.yield %[[VAL_7]] : index
95+
// UNROLL-FULL: %[[C0:.*]] = arith.constant 0 : index
96+
// UNROLL-FULL: %[[C1:.*]] = arith.constant 1 : index
97+
// UNROLL-FULL: %[[C4:.*]] = arith.constant 4 : index
98+
// UNROLL-FULL: %[[SUM0:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[C0]])
99+
// UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][0, %[[IV]]] : index from vector<4x4xindex>
100+
// UNROLL-FULL: scf.yield %[[VAL]] : index
101101
// UNROLL-FULL: }
102-
// UNROLL-FULL: %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_10:.*]] = %[[VAL_4]]) -> (index) {
103-
// UNROLL-FULL: %[[VAL_11:.*]] = vector.extract %[[VAL_0]][1, %[[VAL_9]]] : index from vector<4x4xindex>
104-
// UNROLL-FULL: scf.yield %[[VAL_11]] : index
102+
// UNROLL-FULL: %[[SUM1:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM0]])
103+
// UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][1, %[[IV]]] : index from vector<4x4xindex>
104+
// UNROLL-FULL: scf.yield %[[VAL]] : index
105105
// UNROLL-FULL: }
106-
// UNROLL-FULL: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_8]]) -> (index) {
107-
// UNROLL-FULL: %[[VAL_15:.*]] = vector.extract %[[VAL_0]][2, %[[VAL_13]]] : index from vector<4x4xindex>
108-
// UNROLL-FULL: scf.yield %[[VAL_15]] : index
106+
// UNROLL-FULL: %[[SUM2:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM1]])
107+
// UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][2, %[[IV]]] : index from vector<4x4xindex>
108+
// UNROLL-FULL: scf.yield %[[VAL]] : index
109109
// UNROLL-FULL: }
110-
// UNROLL-FULL: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_18:.*]] = %[[VAL_12]]) -> (index) {
111-
// UNROLL-FULL: %[[VAL_19:.*]] = vector.extract %[[VAL_0]][3, %[[VAL_17]]] : index from vector<4x4xindex>
112-
// UNROLL-FULL: scf.yield %[[VAL_19]] : index
110+
// UNROLL-FULL: %[[SUM3:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM2]])
111+
// UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][3, %[[IV]]] : index from vector<4x4xindex>
112+
// UNROLL-FULL: scf.yield %[[VAL]] : index
113113
// UNROLL-FULL: }
114-
// UNROLL-FULL: return %[[VAL_16]] : index
114+
// UNROLL-FULL: return %[[SUM3]] : index
115115
}

mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ struct TestLoopUnrollingPass
6666
};
6767
for (auto loop : loops) {
6868
if (unrollFull)
69-
loopUnrollFull(loop);
69+
(void)loopUnrollFull(loop);
7070
else
7171
(void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
7272
}

0 commit comments

Comments
 (0)