Skip to content

Commit fba26bc

Browse files
authored
[MLIR] Fix SCF loop specialization (peeling) to work on scf.for with non-index type (#158707)
The current code would crash with integer. This is visible on this modified example (the original with index was incorrect)
1 parent 92f5d8d commit fba26bc

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ static void specializeForLoopForUnrolling(ForOp op) {
9494

9595
OpBuilder b(op);
9696
IRMapping map;
97-
Value constant = arith::ConstantIndexOp::create(b, op.getLoc(), minConstant);
97+
Value constant = arith::ConstantOp::create(
98+
b, op.getLoc(),
99+
IntegerAttr::get(op.getUpperBound().getType(), minConstant));
98100
Value cond = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq,
99101
bound, constant);
100102
map.map(bound, constant);
@@ -150,6 +152,9 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
150152
ValueRange{forOp.getLowerBound(),
151153
forOp.getUpperBound(),
152154
forOp.getStep()});
155+
if (splitBound.getType() != forOp.getLowerBound().getType())
156+
splitBound = b.createOrFold<arith::IndexCastOp>(
157+
loc, forOp.getLowerBound().getType(), splitBound);
153158

154159
// Create ForOp for partial iteration.
155160
b.setInsertionPointAfter(forOp);
@@ -230,6 +235,9 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
230235
auto loc = forOp.getLoc();
231236
Value splitBound = b.createOrFold<AffineApplyOp>(
232237
loc, ubMap, ValueRange{forOp.getLowerBound(), forOp.getStep()});
238+
if (splitBound.getType() != forOp.getUpperBound().getType())
239+
splitBound = b.createOrFold<arith::IndexCastOp>(
240+
loc, forOp.getUpperBound().getType(), splitBound);
233241

234242
// Peel the first iteration.
235243
IRMapping map;

mlir/test/Dialect/SCF/for-loop-peeling.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,41 @@ func.func @fully_static_bounds() -> i32 {
6767

6868
// -----
6969

70+
// CHECK: func @fully_static_bounds_integers(
71+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
72+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
73+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32
74+
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
75+
// CHECK: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C16]]
76+
// CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0]]) -> (i32)
77+
// CHECK: %[[MAP:.*]] = affine.min
78+
// CHECK: %[[MAP_CAST:.*]] = arith.index_cast %[[MAP]]
79+
// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[MAP_CAST]] : i32
80+
// CHECK: scf.yield %[[ADD]]
81+
// CHECK: }
82+
// CHECK: %[[RESULT:.*]] = arith.addi %[[LOOP]], %[[C1]] : i32
83+
// CHECK: return %[[RESULT]]
84+
#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
85+
func.func @fully_static_bounds_integers() -> i32 {
86+
%c0_i32 = arith.constant 0 : i32
87+
%lb = arith.constant 0 : i32
88+
%step = arith.constant 4 : i32
89+
%ub = arith.constant 17 : i32
90+
%r = scf.for %iv = %lb to %ub step %step
91+
iter_args(%arg = %c0_i32) -> i32 : i32 {
92+
%ub_index = arith.index_cast %ub : i32 to index
93+
%iv_index = arith.index_cast %iv : i32 to index
94+
%step_index = arith.index_cast %step : i32 to index
95+
%s = affine.min #map(%ub_index, %iv_index)[%step_index]
96+
%casted = arith.index_cast %s : index to i32
97+
%0 = arith.addi %arg, %casted : i32
98+
scf.yield %0 : i32
99+
}
100+
return %r : i32
101+
}
102+
103+
// -----
104+
70105
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> ((s0 floordiv 4) * 4)>
71106
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0)>
72107
// CHECK: func @dynamic_upper_bound(

0 commit comments

Comments
 (0)