@@ -16,29 +16,27 @@ namespace fir {
16
16
#include " flang/Optimizer/Transforms/Passes.h.inc"
17
17
} // namespace fir
18
18
19
- using namespace fir ;
20
- using namespace mlir ;
21
-
22
19
namespace {
23
20
class FIRToSCFPass : public fir ::impl::FIRToSCFPassBase<FIRToSCFPass> {
24
21
public:
25
22
void runOnOperation () override ;
26
23
};
27
24
28
- struct DoLoopConversion : public OpRewritePattern <fir::DoLoopOp> {
25
+ struct DoLoopConversion : public mlir :: OpRewritePattern<fir::DoLoopOp> {
29
26
using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
30
27
31
- LogicalResult matchAndRewrite (fir::DoLoopOp doLoopOp,
32
- PatternRewriter &rewriter) const override {
33
- auto loc = doLoopOp.getLoc ();
28
+ mlir::LogicalResult
29
+ matchAndRewrite (fir::DoLoopOp doLoopOp,
30
+ mlir::PatternRewriter &rewriter) const override {
31
+ mlir::Location loc = doLoopOp.getLoc ();
34
32
bool hasFinalValue = doLoopOp.getFinalValue ().has_value ();
35
33
36
34
// Get loop values from the DoLoopOp
37
- auto low = doLoopOp.getLowerBound ();
38
- auto high = doLoopOp.getUpperBound ();
35
+ mlir::Value low = doLoopOp.getLowerBound ();
36
+ mlir::Value high = doLoopOp.getUpperBound ();
39
37
assert (low && high && " must be a Value" );
40
- auto step = doLoopOp.getStep ();
41
- llvm::SmallVector<Value> iterArgs;
38
+ mlir::Value step = doLoopOp.getStep ();
39
+ llvm::SmallVector<mlir:: Value> iterArgs;
42
40
if (hasFinalValue)
43
41
iterArgs.push_back (low);
44
42
iterArgs.append (doLoopOp.getIterOperands ().begin (),
@@ -49,31 +47,33 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
49
47
// must be a positive value.
50
48
// For easier conversion, we calculate the trip count and use a canonical
51
49
// induction variable.
52
- auto diff = arith::SubIOp::create (rewriter, loc, high, low);
53
- auto distance = arith::AddIOp::create (rewriter, loc, diff, step);
54
- auto tripCount = arith::DivSIOp::create (rewriter, loc, distance, step);
55
- auto zero = arith::ConstantIndexOp::create (rewriter, loc, 0 );
56
- auto one = arith::ConstantIndexOp::create (rewriter, loc, 1 );
50
+ auto diff = mlir::arith::SubIOp::create (rewriter, loc, high, low);
51
+ auto distance = mlir::arith::AddIOp::create (rewriter, loc, diff, step);
52
+ auto tripCount =
53
+ mlir::arith::DivSIOp::create (rewriter, loc, distance, step);
54
+ auto zero = mlir::arith::ConstantIndexOp::create (rewriter, loc, 0 );
55
+ auto one = mlir::arith::ConstantIndexOp::create (rewriter, loc, 1 );
57
56
auto scfForOp =
58
- scf::ForOp::create (rewriter, loc, zero, tripCount, one, iterArgs);
57
+ mlir:: scf::ForOp::create (rewriter, loc, zero, tripCount, one, iterArgs);
59
58
60
59
auto &loopOps = doLoopOp.getBody ()->getOperations ();
61
- auto resultOp = cast<fir::ResultOp>(doLoopOp.getBody ()->getTerminator ());
60
+ auto resultOp =
61
+ mlir::cast<fir::ResultOp>(doLoopOp.getBody ()->getTerminator ());
62
62
auto results = resultOp.getOperands ();
63
- Block *loweredBody = scfForOp.getBody ();
63
+ mlir:: Block *loweredBody = scfForOp.getBody ();
64
64
65
65
loweredBody->getOperations ().splice (loweredBody->begin (), loopOps,
66
66
loopOps.begin (),
67
67
std::prev (loopOps.end ()));
68
68
69
69
rewriter.setInsertionPointToStart (loweredBody);
70
- Value iv =
71
- arith::MulIOp::create ( rewriter, loc, scfForOp.getInductionVar (), step);
72
- iv = arith::AddIOp::create (rewriter, loc, low, iv);
70
+ mlir:: Value iv = mlir::arith::MulIOp::create (
71
+ rewriter, loc, scfForOp.getInductionVar (), step);
72
+ iv = mlir:: arith::AddIOp::create (rewriter, loc, low, iv);
73
73
74
74
if (!results.empty ()) {
75
75
rewriter.setInsertionPointToEnd (loweredBody);
76
- scf::YieldOp::create (rewriter, resultOp->getLoc (), results);
76
+ mlir:: scf::YieldOp::create (rewriter, resultOp->getLoc (), results);
77
77
}
78
78
doLoopOp.getInductionVar ().replaceAllUsesWith (iv);
79
79
rewriter.replaceAllUsesWith (doLoopOp.getRegionIterArgs (),
@@ -84,34 +84,36 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
84
84
// Copy all the attributes from the old to new op.
85
85
scfForOp->setAttrs (doLoopOp->getAttrs ());
86
86
rewriter.replaceOp (doLoopOp, scfForOp);
87
- return success ();
87
+ return mlir:: success ();
88
88
}
89
89
};
90
90
91
- void copyBlockAndTransformResult (PatternRewriter &rewriter, Block &srcBlock ,
92
- Block &dstBlock) {
93
- Operation *srcTerminator = srcBlock.getTerminator ();
94
- auto resultOp = cast<fir::ResultOp>(srcTerminator);
91
+ void copyBlockAndTransformResult (mlir:: PatternRewriter &rewriter,
92
+ mlir::Block &srcBlock, mlir:: Block &dstBlock) {
93
+ mlir:: Operation *srcTerminator = srcBlock.getTerminator ();
94
+ auto resultOp = mlir:: cast<fir::ResultOp>(srcTerminator);
95
95
96
96
dstBlock.getOperations ().splice (dstBlock.begin (), srcBlock.getOperations (),
97
97
srcBlock.begin (), std::prev (srcBlock.end ()));
98
98
99
99
if (!resultOp->getOperands ().empty ()) {
100
100
rewriter.setInsertionPointToEnd (&dstBlock);
101
- scf::YieldOp::create (rewriter, resultOp->getLoc (), resultOp->getOperands ());
101
+ mlir::scf::YieldOp::create (rewriter, resultOp->getLoc (),
102
+ resultOp->getOperands ());
102
103
}
103
104
104
105
rewriter.eraseOp (srcTerminator);
105
106
}
106
107
107
- struct IfConversion : public OpRewritePattern <fir::IfOp> {
108
+ struct IfConversion : public mlir :: OpRewritePattern<fir::IfOp> {
108
109
using OpRewritePattern<fir::IfOp>::OpRewritePattern;
109
- LogicalResult matchAndRewrite (fir::IfOp ifOp,
110
- PatternRewriter &rewriter) const override {
110
+ mlir::LogicalResult
111
+ matchAndRewrite (fir::IfOp ifOp,
112
+ mlir::PatternRewriter &rewriter) const override {
111
113
bool hasElse = !ifOp.getElseRegion ().empty ();
112
114
auto scfIfOp =
113
- scf::IfOp::create (rewriter, ifOp.getLoc (), ifOp.getResultTypes (),
114
- ifOp.getCondition (), hasElse);
115
+ mlir:: scf::IfOp::create (rewriter, ifOp.getLoc (), ifOp.getResultTypes (),
116
+ ifOp.getCondition (), hasElse);
115
117
116
118
copyBlockAndTransformResult (rewriter, ifOp.getThenRegion ().front (),
117
119
scfIfOp.getThenRegion ().front ());
@@ -123,22 +125,22 @@ struct IfConversion : public OpRewritePattern<fir::IfOp> {
123
125
124
126
scfIfOp->setAttrs (ifOp->getAttrs ());
125
127
rewriter.replaceOp (ifOp, scfIfOp);
126
- return success ();
128
+ return mlir:: success ();
127
129
}
128
130
};
129
131
} // namespace
130
132
131
133
void FIRToSCFPass::runOnOperation () {
132
- RewritePatternSet patterns (&getContext ());
134
+ mlir:: RewritePatternSet patterns (&getContext ());
133
135
patterns.add <DoLoopConversion, IfConversion>(patterns.getContext ());
134
- ConversionTarget target (getContext ());
136
+ mlir:: ConversionTarget target (getContext ());
135
137
target.addIllegalOp <fir::DoLoopOp, fir::IfOp>();
136
- target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
138
+ target.markUnknownOpDynamicallyLegal ([](mlir:: Operation *) { return true ; });
137
139
if (failed (
138
140
applyPartialConversion (getOperation (), target, std::move (patterns))))
139
141
signalPassFailure ();
140
142
}
141
143
142
- std::unique_ptr<Pass> fir::createFIRToSCFPass () {
144
+ std::unique_ptr<mlir:: Pass> fir::createFIRToSCFPass () {
143
145
return std::make_unique<FIRToSCFPass>();
144
146
}
0 commit comments