@@ -87,13 +87,48 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
8787 return success ();
8888 }
8989};
90+
91+ struct IfConversion : public OpRewritePattern <fir::IfOp> {
92+ using OpRewritePattern<fir::IfOp>::OpRewritePattern;
93+ LogicalResult matchAndRewrite (fir::IfOp ifOp,
94+ PatternRewriter &rewriter) const override {
95+ mlir::Location loc = ifOp.getLoc ();
96+ mlir::detail::TypedValue<mlir::IntegerType> condition = ifOp.getCondition ();
97+ ValueTypeRange<ResultRange> resultTypes = ifOp.getResultTypes ();
98+ mlir::scf::IfOp scfIfOp = rewriter.create <scf::IfOp>(
99+ loc, resultTypes, condition, !ifOp.getElseRegion ().empty ());
100+ // then region
101+ scfIfOp.getThenRegion ().takeBody (ifOp.getThenRegion ());
102+ Block &scfThenBlock = scfIfOp.getThenRegion ().front ();
103+ Operation *scfThenTerminator = scfThenBlock.getTerminator ();
104+ // fir.result->scf.yield
105+ rewriter.setInsertionPointToEnd (&scfThenBlock);
106+ rewriter.replaceOpWithNewOp <scf::YieldOp>(scfThenTerminator,
107+ scfThenTerminator->getOperands ());
108+
109+ // else region
110+ if (!ifOp.getElseRegion ().empty ()) {
111+ scfIfOp.getElseRegion ().takeBody (ifOp.getElseRegion ());
112+ mlir::Block &elseBlock = scfIfOp.getElseRegion ().front ();
113+ mlir::Operation *elseTerminator = elseBlock.getTerminator ();
114+
115+ rewriter.setInsertionPointToEnd (&elseBlock);
116+ rewriter.replaceOpWithNewOp <scf::YieldOp>(elseTerminator,
117+ elseTerminator->getOperands ());
118+ }
119+
120+ scfIfOp->setAttrs (ifOp->getAttrs ());
121+ rewriter.replaceOp (ifOp, scfIfOp);
122+ return success ();
123+ }
124+ };
90125} // namespace
91126
92127void FIRToSCFPass::runOnOperation () {
93128 RewritePatternSet patterns (&getContext ());
94- patterns.add <DoLoopConversion>(patterns.getContext ());
129+ patterns.add <DoLoopConversion, IfConversion >(patterns.getContext ());
95130 ConversionTarget target (getContext ());
96- target.addIllegalOp <fir::DoLoopOp>();
131+ target.addIllegalOp <fir::DoLoopOp, fir::IfOp >();
97132 target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
98133 if (failed (
99134 applyPartialConversion (getOperation (), target, std::move (patterns))))
0 commit comments