@@ -25,11 +25,18 @@ class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> {
2525struct DoLoopConversion : public mlir ::OpRewritePattern<fir::DoLoopOp> {
2626 using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
2727
28+ DoLoopConversion (mlir::MLIRContext *context,
29+ bool parallelUnorderedLoop = false ,
30+ mlir::PatternBenefit benefit = 1 )
31+ : OpRewritePattern<fir::DoLoopOp>(context, benefit),
32+ parallelUnorderedLoop (parallelUnorderedLoop) {}
33+
2834 mlir::LogicalResult
2935 matchAndRewrite (fir::DoLoopOp doLoopOp,
3036 mlir::PatternRewriter &rewriter) const override {
3137 mlir::Location loc = doLoopOp.getLoc ();
3238 bool hasFinalValue = doLoopOp.getFinalValue ().has_value ();
39+ bool isUnordered = doLoopOp.getUnordered ().has_value ();
3340
3441 // Get loop values from the DoLoopOp
3542 mlir::Value low = doLoopOp.getLowerBound ();
@@ -53,39 +60,54 @@ struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
5360 mlir::arith::DivSIOp::create (rewriter, loc, distance, step);
5461 auto zero = mlir::arith::ConstantIndexOp::create (rewriter, loc, 0 );
5562 auto one = mlir::arith::ConstantIndexOp::create (rewriter, loc, 1 );
56- auto scfForOp =
57- mlir::scf::ForOp::create (rewriter, loc, zero, tripCount, one, iterArgs);
5863
64+ // Create the scf.for or scf.parallel operation
65+ mlir::Operation *scfLoopOp = nullptr ;
66+ if (isUnordered && parallelUnorderedLoop) {
67+ scfLoopOp = mlir::scf::ParallelOp::create (rewriter, loc, {zero},
68+ {tripCount}, {one}, iterArgs);
69+ } else {
70+ scfLoopOp = mlir::scf::ForOp::create (rewriter, loc, zero, tripCount, one,
71+ iterArgs);
72+ }
73+
74+ // Move the body of the fir.do_loop to the scf.for or scf.parallel
5975 auto &loopOps = doLoopOp.getBody ()->getOperations ();
6076 auto resultOp =
6177 mlir::cast<fir::ResultOp>(doLoopOp.getBody ()->getTerminator ());
6278 auto results = resultOp.getOperands ();
63- mlir::Block *loweredBody = scfForOp.getBody ();
79+ auto scfLoopLikeOp = mlir::cast<mlir::LoopLikeOpInterface>(scfLoopOp);
80+ mlir::Block &scfLoopBody = scfLoopLikeOp.getLoopRegions ().front ()->front ();
6481
65- loweredBody-> getOperations ().splice (loweredBody-> begin (), loopOps,
66- loopOps.begin (),
67- std::prev (loopOps.end ()));
82+ scfLoopBody. getOperations ().splice (scfLoopBody. begin (), loopOps,
83+ loopOps.begin (),
84+ std::prev (loopOps.end ()));
6885
69- rewriter.setInsertionPointToStart (loweredBody );
86+ rewriter.setInsertionPointToStart (&scfLoopBody );
7087 mlir::Value iv = mlir::arith::MulIOp::create (
71- rewriter, loc, scfForOp. getInductionVar (), step);
88+ rewriter, loc, scfLoopLikeOp. getSingleInductionVar (). value (), step);
7289 iv = mlir::arith::AddIOp::create (rewriter, loc, low, iv);
7390
7491 if (!results.empty ()) {
75- rewriter.setInsertionPointToEnd (loweredBody );
92+ rewriter.setInsertionPointToEnd (&scfLoopBody );
7693 mlir::scf::YieldOp::create (rewriter, resultOp->getLoc (), results);
7794 }
7895 doLoopOp.getInductionVar ().replaceAllUsesWith (iv);
79- rewriter.replaceAllUsesWith (doLoopOp.getRegionIterArgs (),
80- hasFinalValue
81- ? scfForOp.getRegionIterArgs ().drop_front ()
82- : scfForOp.getRegionIterArgs ());
83-
84- // Copy all the attributes from the old to new op.
85- scfForOp->setAttrs (doLoopOp->getAttrs ());
86- rewriter.replaceOp (doLoopOp, scfForOp);
96+ rewriter.replaceAllUsesWith (
97+ doLoopOp.getRegionIterArgs (),
98+ hasFinalValue ? scfLoopLikeOp.getRegionIterArgs ().drop_front ()
99+ : scfLoopLikeOp.getRegionIterArgs ());
100+
101+ // Copy loop annotations from the fir.do_loop to scf loop op.
102+ if (auto ann = doLoopOp.getLoopAnnotation ())
103+ scfLoopOp->setAttr (" loop_annotation" , *ann);
104+
105+ rewriter.replaceOp (doLoopOp, scfLoopOp);
87106 return mlir::success ();
88107 }
108+
109+ private:
110+ bool parallelUnorderedLoop;
89111};
90112
91113struct IterWhileConversion : public mlir ::OpRewritePattern<fir::IterWhileOp> {
@@ -197,10 +219,15 @@ struct IfConversion : public mlir::OpRewritePattern<fir::IfOp> {
197219};
198220} // namespace
199221
222+ void fir::populateFIRToSCFRewrites (mlir::RewritePatternSet &patterns,
223+ bool parallelUnordered) {
224+ patterns.add <IterWhileConversion, IfConversion>(patterns.getContext ());
225+ patterns.add <DoLoopConversion>(patterns.getContext (), parallelUnordered);
226+ }
227+
200228void FIRToSCFPass::runOnOperation () {
201229 mlir::RewritePatternSet patterns (&getContext ());
202- patterns.add <DoLoopConversion, IterWhileConversion, IfConversion>(
203- patterns.getContext ());
230+ fir::populateFIRToSCFRewrites (patterns, parallelUnordered);
204231 walkAndApplyPatterns (getOperation (), std::move (patterns));
205232}
206233
0 commit comments