Skip to content

Commit c164e63

Browse files
[flang][fir] Add conversion of fir.iterate_while to scf.while. (#152439)
This commmit is a supplement for #140374. RFC:https://discourse.llvm.org/t/rfc-add-fir-affine-optimization-fir-pass-pipeline/86190/6
1 parent aa503f6 commit c164e63

File tree

2 files changed

+170
-3
lines changed

2 files changed

+170
-3
lines changed

flang/lib/Optimizer/Transforms/FIRToSCF.cpp

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
3636
mlir::Value high = doLoopOp.getUpperBound();
3737
assert(low && high && "must be a Value");
3838
mlir::Value step = doLoopOp.getStep();
39-
llvm::SmallVector<mlir::Value> iterArgs;
39+
mlir::SmallVector<mlir::Value> iterArgs;
4040
if (hasFinalValue)
4141
iterArgs.push_back(low);
4242
iterArgs.append(doLoopOp.getIterOperands().begin(),
@@ -88,6 +88,73 @@ struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
8888
}
8989
};
9090

91+
struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> {
92+
using OpRewritePattern<fir::IterWhileOp>::OpRewritePattern;
93+
94+
mlir::LogicalResult
95+
matchAndRewrite(fir::IterWhileOp iterWhileOp,
96+
mlir::PatternRewriter &rewriter) const override {
97+
98+
mlir::Location loc = iterWhileOp.getLoc();
99+
mlir::Value lowerBound = iterWhileOp.getLowerBound();
100+
mlir::Value upperBound = iterWhileOp.getUpperBound();
101+
mlir::Value step = iterWhileOp.getStep();
102+
103+
mlir::Value okInit = iterWhileOp.getIterateIn();
104+
mlir::ValueRange iterArgs = iterWhileOp.getInitArgs();
105+
106+
mlir::SmallVector<mlir::Value> initVals;
107+
initVals.push_back(lowerBound);
108+
initVals.push_back(okInit);
109+
initVals.append(iterArgs.begin(), iterArgs.end());
110+
111+
mlir::SmallVector<mlir::Type> loopTypes;
112+
loopTypes.push_back(lowerBound.getType());
113+
loopTypes.push_back(okInit.getType());
114+
for (auto val : iterArgs)
115+
loopTypes.push_back(val.getType());
116+
117+
auto scfWhileOp =
118+
mlir::scf::WhileOp::create(rewriter, loc, loopTypes, initVals);
119+
120+
auto &beforeBlock = *rewriter.createBlock(
121+
&scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes,
122+
mlir::SmallVector<mlir::Location>(loopTypes.size(), loc));
123+
124+
mlir::Region::BlockArgListType argsInBefore =
125+
scfWhileOp.getBefore().getArguments();
126+
auto ivInBefore = argsInBefore[0];
127+
auto earlyExitInBefore = argsInBefore[1];
128+
129+
rewriter.setInsertionPointToStart(&beforeBlock);
130+
131+
mlir::Value inductionCmp = mlir::arith::CmpIOp::create(
132+
rewriter, loc, mlir::arith::CmpIPredicate::sle, ivInBefore, upperBound);
133+
mlir::Value cond = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp,
134+
earlyExitInBefore);
135+
136+
mlir::scf::ConditionOp::create(rewriter, loc, cond, argsInBefore);
137+
138+
rewriter.moveBlockBefore(iterWhileOp.getBody(), &scfWhileOp.getAfter(),
139+
scfWhileOp.getAfter().begin());
140+
141+
auto *afterBody = scfWhileOp.getAfterBody();
142+
auto resultOp = mlir::cast<fir::ResultOp>(afterBody->getTerminator());
143+
mlir::SmallVector<mlir::Value> results(resultOp->getOperands());
144+
mlir::Value ivInAfter = scfWhileOp.getAfterArguments()[0];
145+
146+
rewriter.setInsertionPointToStart(afterBody);
147+
results[0] = mlir::arith::AddIOp::create(rewriter, loc, ivInAfter, step);
148+
149+
rewriter.setInsertionPointToEnd(afterBody);
150+
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(resultOp, results);
151+
152+
scfWhileOp->setAttrs(iterWhileOp->getAttrs());
153+
rewriter.replaceOp(iterWhileOp, scfWhileOp);
154+
return mlir::success();
155+
}
156+
};
157+
91158
void copyBlockAndTransformResult(mlir::PatternRewriter &rewriter,
92159
mlir::Block &srcBlock, mlir::Block &dstBlock) {
93160
mlir::Operation *srcTerminator = srcBlock.getTerminator();
@@ -132,9 +199,10 @@ struct IfConversion : public mlir::OpRewritePattern<fir::IfOp> {
132199

133200
void FIRToSCFPass::runOnOperation() {
134201
mlir::RewritePatternSet patterns(&getContext());
135-
patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
202+
patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>(
203+
patterns.getContext());
136204
mlir::ConversionTarget target(getContext());
137-
target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
205+
target.addIllegalOp<fir::DoLoopOp, fir::IterWhileOp, fir::IfOp>();
138206
target.markUnknownOpDynamicallyLegal([](mlir::Operation *) { return true; });
139207
if (failed(
140208
applyPartialConversion(getOperation(), target, std::move(patterns))))
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// RUN: fir-opt %s --fir-to-scf | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
4+
// CHECK: %[[VAL_0:.*]] = arith.constant 11 : index
5+
// CHECK: %[[VAL_1:.*]] = arith.constant 22 : index
6+
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
7+
// CHECK: %[[VAL_3:.*]] = arith.constant true
8+
// CHECK: %[[VAL_4:.*]] = arith.constant 123 : i16
9+
// CHECK: %[[VAL_5:.*]] = arith.constant 456 : i32
10+
// CHECK: %[[VAL_6:.*]]:4 = scf.while (%[[VAL_7:.*]] = %[[VAL_0]], %[[VAL_8:.*]] = %[[VAL_3]], %[[VAL_9:.*]] = %[[VAL_4]], %[[VAL_10:.*]] = %[[VAL_5]]) : (index, i1, i16, i32) -> (index, i1, i16, i32) {
11+
// CHECK: %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_1]] : index
12+
// CHECK: %[[VAL_12:.*]] = arith.andi %[[VAL_11]], %[[VAL_8]] : i1
13+
// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : index, i1, i16, i32
14+
// CHECK: } do {
15+
// CHECK: ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: i1, %[[VAL_15:.*]]: i16, %[[VAL_16:.*]]: i32):
16+
// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : index
17+
// CHECK: %[[VAL_18:.*]] = arith.constant true
18+
// CHECK: %[[VAL_19:.*]] = arith.constant 22 : i16
19+
// CHECK: %[[VAL_20:.*]] = arith.constant 33 : i32
20+
// CHECK: scf.yield %[[VAL_17]], %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, i1, i16, i32
21+
// CHECK: }
22+
// CHECK: return %[[VAL_21:.*]]#0, %[[VAL_21]]#1, %[[VAL_21]]#2, %[[VAL_21]]#3 : index, i1, i16, i32
23+
// CHECK: }
24+
func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
25+
%lo = arith.constant 11 : index
26+
%up = arith.constant 22 : index
27+
%step = arith.constant 2 : index
28+
%ok = arith.constant 1 : i1
29+
%val1 = arith.constant 123 : i16
30+
%val2 = arith.constant 456 : i32
31+
32+
%res:4 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%v1 = %val1, %v2 = %val2) -> (index, i1, i16, i32) {
33+
%new_c = arith.constant 1 : i1
34+
%new_v1 = arith.constant 22 : i16
35+
%new_v2 = arith.constant 33 : i32
36+
fir.result %i, %new_c, %new_v1, %new_v2 : index, i1, i16, i32
37+
}
38+
39+
return %res#0, %res#1, %res#2, %res#3 : index, i1, i16, i32
40+
}
41+
42+
// CHECK-LABEL: func.func @test_simple_iterate_while_2(
43+
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: i32) -> (index, i1, i32) {
44+
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
45+
// CHECK: %[[VAL_1:.*]]:3 = scf.while (%[[VAL_2:.*]] = %[[ARG0]], %[[VAL_3:.*]] = %[[ARG2]], %[[VAL_4:.*]] = %[[ARG3]]) : (index, i1, i32) -> (index, i1, i32) {
46+
// CHECK: %[[VAL_5:.*]] = arith.cmpi sle, %[[VAL_2]], %[[ARG1]] : index
47+
// CHECK: %[[VAL_6:.*]] = arith.andi %[[VAL_5]], %[[VAL_3]] : i1
48+
// CHECK: scf.condition(%[[VAL_6]]) %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : index, i1, i32
49+
// CHECK: } do {
50+
// CHECK: ^bb0(%[[VAL_7:.*]]: index, %[[VAL_8:.*]]: i1, %[[VAL_9:.*]]: i32):
51+
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_7]], %[[VAL_0]] : index
52+
// CHECK: %[[VAL_11:.*]] = arith.constant 123 : i32
53+
// CHECK: %[[VAL_12:.*]] = arith.constant true
54+
// CHECK: scf.yield %[[VAL_10]], %[[VAL_12]], %[[VAL_11]] : index, i1, i32
55+
// CHECK: }
56+
// CHECK: return %[[VAL_13:.*]]#0, %[[VAL_13]]#1, %[[VAL_13]]#2 : index, i1, i32
57+
// CHECK: }
58+
func.func @test_simple_iterate_while_2(%start: index, %stop: index, %cond: i1, %val: i32) -> (index, i1, i32) {
59+
%step = arith.constant 1 : index
60+
61+
%res:3 = fir.iterate_while (%i = %start to %stop step %step) and (%ok = %cond) iter_args(%x = %val) -> (index, i1, i32) {
62+
%new_x = arith.constant 123 : i32
63+
%new_ok = arith.constant 1 : i1
64+
fir.result %i, %new_ok, %new_x : index, i1, i32
65+
}
66+
67+
return %res#0, %res#1, %res#2 : index, i1, i32
68+
}
69+
70+
// CHECK-LABEL: func.func @test_zero_iterations() -> (index, i1, i8) {
71+
// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index
72+
// CHECK: %[[VAL_1:.*]] = arith.constant 5 : index
73+
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
74+
// CHECK: %[[VAL_3:.*]] = arith.constant true
75+
// CHECK: %[[VAL_4:.*]] = arith.constant 42 : i8
76+
// CHECK: %[[VAL_5:.*]]:3 = scf.while (%[[VAL_6:.*]] = %[[VAL_0]], %[[VAL_7:.*]] = %[[VAL_3]], %[[VAL_8:.*]] = %[[VAL_4]]) : (index, i1, i8) -> (index, i1, i8) {
77+
// CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_1]] : index
78+
// CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_9]], %[[VAL_7]] : i1
79+
// CHECK: scf.condition(%[[VAL_10]]) %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : index, i1, i8
80+
// CHECK: } do {
81+
// CHECK: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: i1, %[[VAL_13:.*]]: i8):
82+
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_2]] : index
83+
// CHECK: scf.yield %[[VAL_14]], %[[VAL_12]], %[[VAL_13]] : index, i1, i8
84+
// CHECK: }
85+
// CHECK: return %[[VAL_15:.*]]#0, %[[VAL_15]]#1, %[[VAL_15]]#2 : index, i1, i8
86+
// CHECK: }
87+
func.func @test_zero_iterations() -> (index, i1, i8) {
88+
%lo = arith.constant 10 : index
89+
%up = arith.constant 5 : index
90+
%step = arith.constant 1 : index
91+
%ok = arith.constant 1 : i1
92+
%x = arith.constant 42 : i8
93+
94+
%res:3 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%xv = %x) -> (index, i1, i8) {
95+
fir.result %i, %c, %xv : index, i1, i8
96+
}
97+
98+
return %res#0, %res#1, %res#2 : index, i1, i8
99+
}

0 commit comments

Comments
 (0)