Skip to content

Commit af4fbca

Browse files
authored
[mlir][SCF] Fix UB adjustment during scf.for loop peeling
Currently when peeling the first iteration, any mentioning of UB within the loop body is replaced with the new UB in the peeled out first iteration. This introduces a bug in the following scenario: Operations inside of the loop that intentionally use the original UB are incorrectly updated.
1 parent 6b83e68 commit af4fbca

File tree

3 files changed

+79
-8
lines changed

3 files changed

+79
-8
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,10 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
240240
loc, forOp.getUpperBound().getType(), splitBound);
241241

242242
// Peel the first iteration.
243-
IRMapping map;
244-
map.map(forOp.getUpperBound(), splitBound);
245-
firstIteration = cast<ForOp>(b.clone(*forOp.getOperation(), map));
246-
243+
firstIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
244+
b.modifyOpInPlace(firstIteration, [&]() {
245+
firstIteration.getUpperBoundMutable().assign(splitBound);
246+
});
247247
// Update main loop with new lower bound.
248248
b.modifyOpInPlace(forOp, [&]() {
249249
forOp.getInitArgsMutable().assign(firstIteration->getResults());

mlir/test/Dialect/SCF/for-loop-peeling-front.mlir

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
// CHECK-DAG: %[[C17:.*]] = arith.constant 17 : index
99
// CHECK: %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]]
1010
// CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
11-
// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C4]], %[[IV]])[%[[C4]]]
11+
// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C17]], %[[IV]])[%[[C4]]]
1212
// CHECK: %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32
1313
// CHECK: %[[INIT:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
1414
// CHECK: scf.yield %[[INIT]]
@@ -35,7 +35,49 @@ func.func @fully_static_bounds() -> i32 {
3535
}
3636
return %r : i32
3737
}
38+
// -----
3839

40+
// CHECK-LABEL: func.func @static_two_iterations_ub_used_in_loop(
41+
// CHECK-SAME: %[[IFM1:.*]]: memref<1xi32>) -> i32 {
42+
// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
43+
// CHECK: %[[C0_IDX:.*]] = arith.constant 0 : index
44+
// CHECK: %[[C4_IDX:.*]] = arith.constant 4 : index
45+
// CHECK: %[[C7_IDX:.*]] = arith.constant 7 : index
46+
// CHECK: %[[FOR1:.*]] = scf.for %[[IV1:.*]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C4_IDX]] iter_args(%[[ARG1:.*]] = %[[C0_I32]]) -> (i32) {
47+
// CHECK: %[[AFFINE_MIN1:.*]] = affine.min #map(%[[C7_IDX]], %[[IV1]]){{\[}}%[[C4_IDX]]]
48+
// CHECK: %[[CAST1:.*]] = arith.index_cast %[[AFFINE_MIN1]] : index to i32
49+
// CHECK: %[[ADDI1:.*]] = arith.addi %[[ARG1]], %[[CAST1]] : i32
50+
// CHECK: %[[LOAD1:.*]] = memref.load %[[IFM1]]{{\[}}%[[C7_IDX]]] : memref<1xi32>
51+
// CHECK: %[[ADDI2:.*]] = arith.addi %[[ADDI1]], %[[LOAD1]] : i32
52+
// CHECK: scf.yield %[[ADDI2]] : i32
53+
// CHECK: }
54+
// CHECK: %[[FOR2:.*]] = scf.for %[[IV2:.*]] = %[[C4_IDX]] to %[[C7_IDX]] step %[[C4_IDX]] iter_args(%[[ARG2:.*]] = %[[RESULT1:.*]]) -> (i32) {
55+
// CHECK: %[[AFFINE_MIN2:.*]] = affine.min #map(%[[C7_IDX]], %[[IV2]]){{\[}}%[[C4_IDX]]]
56+
// CHECK: %[[CAST2:.*]] = arith.index_cast %[[AFFINE_MIN2]] : index to i32
57+
// CHECK: %[[ADDI3:.*]] = arith.addi %[[ARG2]], %[[CAST2]] : i32
58+
// CHECK: %[[LOAD2:.*]] = memref.load %[[IFM1]]{{\[}}%[[C7_IDX]]] : memref<1xi32>
59+
// CHECK: %[[ADDI4:.*]] = arith.addi %[[ADDI3]], %[[LOAD2]] : i32
60+
// CHECK: scf.yield %[[ADDI4]] : i32
61+
// CHECK: }
62+
// CHECK: return %[[RESULT2:.*]] : i32
63+
// CHECK: }
64+
65+
#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
66+
func.func @static_two_iterations_ub_used_in_loop(%arg0: memref<1xi32>) -> i32 {
67+
%c0_i32 = arith.constant 0 : i32
68+
%lb = arith.constant 0 : index
69+
%step = arith.constant 4 : index
70+
%ub = arith.constant 7 : index
71+
%r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0_i32) -> i32 {
72+
%s = affine.min #map(%ub, %iv)[%step]
73+
%casted = arith.index_cast %s : index to i32
74+
%0 = arith.addi %arg, %casted : i32
75+
%1 = memref.load %arg0[%ub] : memref<1xi32>
76+
%result = arith.addi %0, %1 : i32
77+
scf.yield %result : i32
78+
}
79+
return %r : i32
80+
}
3981
// -----
4082

4183
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)>
@@ -44,7 +86,7 @@ func.func @fully_static_bounds() -> i32 {
4486
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
4587
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
4688
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C4]] {
47-
// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C4]], %[[IV]])[%[[C4]]]
89+
// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[UB]], %[[IV]])[%[[C4]]]
4890
// CHECK: %[[LOAD:.*]] = memref.load %[[MEMREF]][]
4991
// CHECK: %[[CAST:.*]] = arith.index_cast %[[MIN]]
5092
// CHECK: %[[ADD:.*]] = arith.addi %[[LOAD]], %[[CAST]] : i32
@@ -83,7 +125,7 @@ func.func @no_loop_results(%ub : index, %d : memref<i32>) {
83125
// CHECK: %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[LB]], %[[STEP]]]
84126
// CHECK: %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[NEW_UB]]
85127
// CHECK-SAME: step %[[STEP]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
86-
// CHECK: %[[MIN:.*]] = affine.min #[[MAP1]](%[[NEW_UB]], %[[IV]])[%[[STEP]]]
128+
// CHECK: %[[MIN:.*]] = affine.min #[[MAP1]](%[[UB]], %[[IV]])[%[[STEP]]]
87129
// CHECK: %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32
88130
// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
89131
// CHECK: scf.yield %[[ADD]]
@@ -119,7 +161,7 @@ func.func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
119161
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
120162
// CHECK: %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[C2]] to %[[C6]]
121163
// CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
122-
// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C6]], %[[IV]])[%[[C4]]]
164+
// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C8]], %[[IV]])[%[[C4]]]
123165
// CHECK: %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32
124166
// CHECK: %[[INIT:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
125167
// CHECK: scf.yield %[[INIT]]

mlir/test/Dialect/SCF/for-loop-peeling.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,35 @@ func.func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
3535

3636
// -----
3737

38+
// CHECK-LABEL: func.func @static_two_iterations_ub_used_in_loop(
39+
// CHECK-SAME: %[[IFM1:.*]]: memref<1xi32>) -> i32 {
40+
// CHECK: %[[C7_I32:.*]] = arith.constant 7 : i32
41+
// CHECK: %[[C7_IDX:.*]] = arith.constant 7 : index
42+
// CHECK: %[[LOAD1:.*]] = memref.load %[[IFM1]]{{\[}}%[[C7_IDX]]] : memref<1xi32>
43+
// CHECK: %[[ADDI1:.*]] = arith.addi %[[LOAD1]], %[[C7_I32]] : i32
44+
// CHECK: %[[LOAD2:.*]] = memref.load %[[IFM1]]{{\[}}%[[C7_IDX]]] : memref<1xi32>
45+
// CHECK: %[[ADDI2:.*]] = arith.addi %[[ADDI1]], %[[LOAD2]] : i32
46+
// CHECK: return %[[ADDI2]] : i32
47+
// CHECK: }
48+
49+
#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
50+
func.func @static_two_iterations_ub_used_in_loop(%arg0: memref<1xi32>) -> i32 {
51+
%c0_i32 = arith.constant 0 : i32
52+
%lb = arith.constant 0 : index
53+
%step = arith.constant 4 : index
54+
%ub = arith.constant 7 : index
55+
%r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0_i32) -> i32 {
56+
%s = affine.min #map(%ub, %iv)[%step]
57+
%casted = arith.index_cast %s : index to i32
58+
%0 = arith.addi %arg, %casted : i32
59+
%1 = memref.load %arg0[%ub] : memref<1xi32>
60+
%result = arith.addi %0, %1 : i32
61+
scf.yield %result : i32
62+
}
63+
return %r : i32
64+
}
65+
// -----
66+
3867
// CHECK: func @fully_static_bounds(
3968
// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
4069
// CHECK-DAG: %[[C1_I32:.*]] = arith.constant 1 : i32

0 commit comments

Comments
 (0)