1414
1515#include " mlir/Dialect/Arith/IR/Arith.h"
1616#include " mlir/Dialect/EmitC/IR/EmitC.h"
17+ #include " mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1718#include " mlir/Dialect/SCF/IR/SCF.h"
1819#include " mlir/IR/Builders.h"
1920#include " mlir/IR/BuiltinOps.h"
2021#include " mlir/IR/IRMapping.h"
2122#include " mlir/IR/MLIRContext.h"
2223#include " mlir/IR/PatternMatch.h"
2324#include " mlir/Transforms/DialectConversion.h"
25+ #include " mlir/Transforms/OneToNTypeConversion.h"
2426#include " mlir/Transforms/Passes.h"
2527
2628namespace mlir {
@@ -39,21 +41,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
3941
4042// Lower scf::for to emitc::for, implementing result values using
4143// emitc::variable's updated within the loop body.
42- struct ForLowering : public OpRewritePattern <ForOp> {
43- using OpRewritePattern <ForOp>::OpRewritePattern ;
44+ struct ForLowering : public OpConversionPattern <ForOp> {
45+ using OpConversionPattern <ForOp>::OpConversionPattern ;
4446
45- LogicalResult matchAndRewrite (ForOp forOp,
46- PatternRewriter &rewriter) const override ;
47+ LogicalResult
48+ matchAndRewrite (ForOp forOp, OpAdaptor adaptor,
49+ ConversionPatternRewriter &rewriter) const override ;
4750};
4851
4952// Create an uninitialized emitc::variable op for each result of the given op.
5053template <typename T>
51- static SmallVector<Value> createVariablesForResults (T op,
52- PatternRewriter &rewriter) {
53- SmallVector<Value> resultVariables;
54-
54+ static LogicalResult
55+ createVariablesForResults (T op, const TypeConverter *typeConverter,
56+ ConversionPatternRewriter &rewriter,
57+ SmallVector<Value> &resultVariables) {
5558 if (!op.getNumResults ())
56- return resultVariables ;
59+ return success () ;
5760
5861 Location loc = op->getLoc ();
5962 MLIRContext *context = op.getContext ();
@@ -62,21 +65,23 @@ static SmallVector<Value> createVariablesForResults(T op,
6265 rewriter.setInsertionPoint (op);
6366
6467 for (OpResult result : op.getResults ()) {
65- Type resultType = result.getType ();
68+ Type resultType = typeConverter->convertType (result.getType ());
69+ if (!resultType)
70+ return rewriter.notifyMatchFailure (op, " result type conversion failed" );
6671 Type varType = emitc::LValueType::get (resultType);
6772 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
6873 emitc::VariableOp var =
6974 rewriter.create <emitc::VariableOp>(loc, varType, noInit);
7075 resultVariables.push_back (var);
7176 }
7277
73- return resultVariables ;
78+ return success () ;
7479}
7580
7681// Create a series of assign ops assigning given values to given variables at
7782// the current insertion point of given rewriter.
78- static void assignValues (ValueRange values, SmallVector<Value> & variables,
79- PatternRewriter &rewriter, Location loc) {
83+ static void assignValues (ValueRange values, ValueRange variables,
84+ ConversionPatternRewriter &rewriter, Location loc) {
8085 for (auto [value, var] : llvm::zip (values, variables))
8186 rewriter.create <emitc::AssignOp>(loc, var, value);
8287}
@@ -89,46 +94,58 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
8994 });
9095}
9196
92- static void lowerYield (SmallVector<Value> &resultVariables,
93- PatternRewriter &rewriter, scf::YieldOp yield) {
97+ static LogicalResult lowerYield (Operation *op, ValueRange resultVariables,
98+ ConversionPatternRewriter &rewriter,
99+ scf::YieldOp yield) {
94100 Location loc = yield.getLoc ();
95- ValueRange operands = yield.getOperands ();
96101
97102 OpBuilder::InsertionGuard guard (rewriter);
98103 rewriter.setInsertionPoint (yield);
99104
100- assignValues (operands, resultVariables, rewriter, loc);
105+ SmallVector<Value> yieldOperands;
106+ if (failed (rewriter.getRemappedValues (yield.getOperands (), yieldOperands))) {
107+ return rewriter.notifyMatchFailure (op, " failed to lower yield operands" );
108+ }
109+
110+ assignValues (yieldOperands, resultVariables, rewriter, loc);
101111
102112 rewriter.create <emitc::YieldOp>(loc);
103113 rewriter.eraseOp (yield);
114+
115+ return success ();
104116}
105117
106118// Lower the contents of an scf::if/scf::index_switch regions to an
107119// emitc::if/emitc::switch region. The contents of the lowering region is
108120// moved into the respective lowered region, but the scf::yield is replaced not
109121// only with an emitc::yield, but also with a sequence of emitc::assign ops that
110122// set the yielded values into the result variables.
111- static void lowerRegion (SmallVector<Value> & resultVariables,
112- PatternRewriter &rewriter, Region ®ion ,
113- Region &loweredRegion) {
123+ static LogicalResult lowerRegion (Operation *op, ValueRange resultVariables,
124+ ConversionPatternRewriter &rewriter ,
125+ Region ®ion, Region &loweredRegion) {
114126 rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
115127 Operation *terminator = loweredRegion.back ().getTerminator ();
116- lowerYield (resultVariables, rewriter, cast<scf::YieldOp>(terminator));
128+ return lowerYield (op, resultVariables, rewriter,
129+ cast<scf::YieldOp>(terminator));
117130}
118131
119- LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
120- PatternRewriter &rewriter) const {
132+ LogicalResult
133+ ForLowering::matchAndRewrite (ForOp forOp, OpAdaptor adaptor,
134+ ConversionPatternRewriter &rewriter) const {
121135 Location loc = forOp.getLoc ();
122136
123137 // Create an emitc::variable op for each result. These variables will be
124138 // assigned to by emitc::assign ops within the loop body.
125- SmallVector<Value> resultVariables =
126- createVariablesForResults (forOp, rewriter);
139+ SmallVector<Value> resultVariables;
140+ if (failed (createVariablesForResults (forOp, getTypeConverter (), rewriter,
141+ resultVariables)))
142+ return rewriter.notifyMatchFailure (forOp,
143+ " create variables for results failed" );
127144
128- assignValues (forOp. getInits (), resultVariables, rewriter, loc);
145+ assignValues (adaptor. getInitArgs (), resultVariables, rewriter, loc);
129146
130147 emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
131- loc, forOp .getLowerBound (), forOp .getUpperBound (), forOp .getStep ());
148+ loc, adaptor .getLowerBound (), adaptor .getUpperBound (), adaptor .getStep ());
132149
133150 Block *loweredBody = loweredFor.getBody ();
134151
@@ -143,13 +160,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
143160
144161 rewriter.restoreInsertionPoint (ip);
145162
163+ // Convert the original region types into the new types by adding unrealized
164+ // casts in the beginning of the loop. This performs the conversion in place.
165+ if (failed (rewriter.convertRegionTypes (&forOp.getRegion (),
166+ *getTypeConverter (), nullptr ))) {
167+ return rewriter.notifyMatchFailure (forOp, " region types conversion failed" );
168+ }
169+
170+ // Register the replacements for the block arguments and inline the body of
171+ // the scf.for loop into the body of the emitc::for loop.
172+ Block *scfBody = &(forOp.getRegion ().front ());
146173 SmallVector<Value> replacingValues;
147174 replacingValues.push_back (loweredFor.getInductionVar ());
148175 replacingValues.append (iterArgsValues.begin (), iterArgsValues.end ());
176+ rewriter.mergeBlocks (scfBody, loweredBody, replacingValues);
149177
150- rewriter.mergeBlocks (forOp.getBody (), loweredBody, replacingValues);
151- lowerYield (resultVariables, rewriter,
152- cast<scf::YieldOp>(loweredBody->getTerminator ()));
178+ auto result = lowerYield (forOp, resultVariables, rewriter,
179+ cast<scf::YieldOp>(loweredBody->getTerminator ()));
180+
181+ if (failed (result)) {
182+ return result;
183+ }
153184
154185 // Load variables into SSA values after the for loop.
155186 SmallVector<Value> resultValues = loadValues (resultVariables, rewriter, loc);
@@ -160,38 +191,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
160191
161192// Lower scf::if to emitc::if, implementing result values as emitc::variable's
162193// updated within the then and else regions.
163- struct IfLowering : public OpRewritePattern <IfOp> {
164- using OpRewritePattern <IfOp>::OpRewritePattern ;
194+ struct IfLowering : public OpConversionPattern <IfOp> {
195+ using OpConversionPattern <IfOp>::OpConversionPattern ;
165196
166- LogicalResult matchAndRewrite (IfOp ifOp,
167- PatternRewriter &rewriter) const override ;
197+ LogicalResult
198+ matchAndRewrite (IfOp ifOp, OpAdaptor adaptor,
199+ ConversionPatternRewriter &rewriter) const override ;
168200};
169201
170202} // namespace
171203
172- LogicalResult IfLowering::matchAndRewrite (IfOp ifOp,
173- PatternRewriter &rewriter) const {
204+ LogicalResult
205+ IfLowering::matchAndRewrite (IfOp ifOp, OpAdaptor adaptor,
206+ ConversionPatternRewriter &rewriter) const {
174207 Location loc = ifOp.getLoc ();
175208
176209 // Create an emitc::variable op for each result. These variables will be
177210 // assigned to by emitc::assign ops within the then & else regions.
178- SmallVector<Value> resultVariables =
179- createVariablesForResults (ifOp, rewriter);
180-
181- Region &thenRegion = ifOp.getThenRegion ();
182- Region &elseRegion = ifOp.getElseRegion ();
211+ SmallVector<Value> resultVariables;
212+ if (failed (createVariablesForResults (ifOp, getTypeConverter (), rewriter,
213+ resultVariables)))
214+ return rewriter.notifyMatchFailure (ifOp,
215+ " create variables for results failed" );
216+
217+ // Utility function to lower the contents of an scf::if region to an emitc::if
218+ // region. The contents of the scf::if regions is moved into the respective
219+ // emitc::if regions, but the scf::yield is replaced not only with an
220+ // emitc::yield, but also with a sequence of emitc::assign ops that set the
221+ // yielded values into the result variables.
222+ auto lowerRegion = [&resultVariables, &rewriter,
223+ &ifOp](Region ®ion, Region &loweredRegion) {
224+ rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
225+ Operation *terminator = loweredRegion.back ().getTerminator ();
226+ auto result = lowerYield (ifOp, resultVariables, rewriter,
227+ cast<scf::YieldOp>(terminator));
228+ if (failed (result)) {
229+ return result;
230+ }
231+ return success ();
232+ };
233+
234+ Region &thenRegion = adaptor.getThenRegion ();
235+ Region &elseRegion = adaptor.getElseRegion ();
183236
184237 bool hasElseBlock = !elseRegion.empty ();
185238
186239 auto loweredIf =
187- rewriter.create <emitc::IfOp>(loc, ifOp .getCondition (), false , false );
240+ rewriter.create <emitc::IfOp>(loc, adaptor .getCondition (), false , false );
188241
189242 Region &loweredThenRegion = loweredIf.getThenRegion ();
190- lowerRegion (resultVariables, rewriter, thenRegion, loweredThenRegion);
243+ auto result = lowerRegion (thenRegion, loweredThenRegion);
244+ if (failed (result)) {
245+ return result;
246+ }
191247
192248 if (hasElseBlock) {
193249 Region &loweredElseRegion = loweredIf.getElseRegion ();
194- lowerRegion (resultVariables, rewriter, elseRegion, loweredElseRegion);
250+ auto result = lowerRegion (elseRegion, loweredElseRegion);
251+ if (failed (result)) {
252+ return result;
253+ }
195254 }
196255
197256 rewriter.setInsertionPointAfter (ifOp);
@@ -203,37 +262,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
203262
204263// Lower scf::index_switch to emitc::switch, implementing result values as
205264// emitc::variable's updated within the case and default regions.
206- struct IndexSwitchOpLowering : public OpRewritePattern <IndexSwitchOp> {
207- using OpRewritePattern<IndexSwitchOp>::OpRewritePattern ;
265+ struct IndexSwitchOpLowering : public OpConversionPattern <IndexSwitchOp> {
266+ using OpConversionPattern::OpConversionPattern ;
208267
209- LogicalResult matchAndRewrite (IndexSwitchOp indexSwitchOp,
210- PatternRewriter &rewriter) const override ;
268+ LogicalResult
269+ matchAndRewrite (IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
270+ ConversionPatternRewriter &rewriter) const override ;
211271};
212272
213- LogicalResult
214- IndexSwitchOpLowering::matchAndRewrite ( IndexSwitchOp indexSwitchOp,
215- PatternRewriter &rewriter) const {
273+ LogicalResult IndexSwitchOpLowering::matchAndRewrite (
274+ IndexSwitchOp indexSwitchOp, OpAdaptor adaptor ,
275+ ConversionPatternRewriter &rewriter) const {
216276 Location loc = indexSwitchOp.getLoc ();
217277
218278 // Create an emitc::variable op for each result. These variables will be
219279 // assigned to by emitc::assign ops within the case and default regions.
220- SmallVector<Value> resultVariables =
221- createVariablesForResults (indexSwitchOp, rewriter);
280+ SmallVector<Value> resultVariables;
281+ if (failed (createVariablesForResults (indexSwitchOp, getTypeConverter (),
282+ rewriter, resultVariables))) {
283+ return rewriter.notifyMatchFailure (indexSwitchOp,
284+ " create variables for results failed" );
285+ }
222286
223287 auto loweredSwitch = rewriter.create <emitc::SwitchOp>(
224- loc, indexSwitchOp.getArg (), indexSwitchOp.getCases (),
225- indexSwitchOp.getNumCases ());
288+ loc, adaptor.getArg (), adaptor.getCases (), indexSwitchOp.getNumCases ());
226289
227290 // Lowering all case regions.
228- for (auto pair : llvm::zip (indexSwitchOp.getCaseRegions (),
229- loweredSwitch.getCaseRegions ())) {
230- lowerRegion (resultVariables, rewriter, std::get<0 >(pair),
231- std::get<1 >(pair));
291+ for (auto pair :
292+ llvm::zip (adaptor.getCaseRegions (), loweredSwitch.getCaseRegions ())) {
293+ if (failed (lowerRegion (indexSwitchOp, resultVariables, rewriter,
294+ *std::get<0 >(pair), std::get<1 >(pair)))) {
295+ return failure ();
296+ }
232297 }
233298
234299 // Lowering default region.
235- lowerRegion (resultVariables, rewriter, indexSwitchOp.getDefaultRegion (),
236- loweredSwitch.getDefaultRegion ());
300+ if (failed (lowerRegion (indexSwitchOp, resultVariables, rewriter,
301+ adaptor.getDefaultRegion (),
302+ loweredSwitch.getDefaultRegion ()))) {
303+ return failure ();
304+ }
237305
238306 rewriter.setInsertionPointAfter (indexSwitchOp);
239307 SmallVector<Value> results = loadValues (resultVariables, rewriter, loc);
@@ -242,15 +310,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
242310 return success ();
243311}
244312
245- void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns) {
246- patterns.add <ForLowering>(patterns.getContext ());
247- patterns.add <IfLowering>(patterns.getContext ());
248- patterns.add <IndexSwitchOpLowering>(patterns.getContext ());
313+ void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns,
314+ TypeConverter &typeConverter) {
315+ patterns.add <ForLowering>(typeConverter, patterns.getContext ());
316+ patterns.add <IfLowering>(typeConverter, patterns.getContext ());
317+ patterns.add <IndexSwitchOpLowering>(typeConverter, patterns.getContext ());
249318}
250319
251320void SCFToEmitCPass::runOnOperation () {
252321 RewritePatternSet patterns (&getContext ());
253- populateSCFToEmitCConversionPatterns (patterns);
322+ TypeConverter typeConverter;
323+ // Fallback converter
324+ // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
325+ // Type converters are called most to least recently inserted
326+ typeConverter.addConversion ([](Type t) { return t; });
327+ populateEmitCSizeTTypeConversions (typeConverter);
328+ populateSCFToEmitCConversionPatterns (patterns, typeConverter);
254329
255330 // Configure conversion to lower out SCF operations.
256331 ConversionTarget target (getContext ());
0 commit comments