1414
1515#include " mlir/Dialect/Arith/IR/Arith.h"
1616#include " mlir/Dialect/EmitC/IR/EmitC.h"
17+ #include " mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1718#include " mlir/Dialect/SCF/IR/SCF.h"
1819#include " mlir/IR/Builders.h"
1920#include " mlir/IR/BuiltinOps.h"
@@ -39,21 +40,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
3940
4041// Lower scf::for to emitc::for, implementing result values using
4142// emitc::variable's updated within the loop body.
42- struct ForLowering : public OpRewritePattern <ForOp> {
43- using OpRewritePattern <ForOp>::OpRewritePattern ;
43+ struct ForLowering : public OpConversionPattern <ForOp> {
44+ using OpConversionPattern <ForOp>::OpConversionPattern ;
4445
45- LogicalResult matchAndRewrite (ForOp forOp,
46- PatternRewriter &rewriter) const override ;
46+ LogicalResult
47+ matchAndRewrite (ForOp forOp, OpAdaptor adaptor,
48+ ConversionPatternRewriter &rewriter) const override ;
4749};
4850
4951// Create an uninitialized emitc::variable op for each result of the given op.
5052template <typename T>
51- static SmallVector<Value> createVariablesForResults (T op,
52- PatternRewriter &rewriter) {
53- SmallVector<Value> resultVariables;
54-
53+ static LogicalResult
54+ createVariablesForResults (T op, const TypeConverter *typeConverter,
55+ ConversionPatternRewriter &rewriter,
56+ SmallVector<Value> &resultVariables) {
5557 if (!op.getNumResults ())
56- return resultVariables ;
58+ return success () ;
5759
5860 Location loc = op->getLoc ();
5961 MLIRContext *context = op.getContext ();
@@ -62,21 +64,23 @@ static SmallVector<Value> createVariablesForResults(T op,
6264 rewriter.setInsertionPoint (op);
6365
6466 for (OpResult result : op.getResults ()) {
65- Type resultType = result.getType ();
67+ Type resultType = typeConverter->convertType (result.getType ());
68+ if (!resultType)
69+ return rewriter.notifyMatchFailure (op, " result type conversion failed" );
6670 Type varType = emitc::LValueType::get (resultType);
6771 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
6872 emitc::VariableOp var =
6973 rewriter.create <emitc::VariableOp>(loc, varType, noInit);
7074 resultVariables.push_back (var);
7175 }
7276
73- return resultVariables ;
77+ return success () ;
7478}
7579
7680// Create a series of assign ops assigning given values to given variables at
7781// the current insertion point of given rewriter.
78- static void assignValues (ValueRange values, SmallVector<Value> & variables,
79- PatternRewriter &rewriter, Location loc) {
82+ static void assignValues (ValueRange values, ValueRange variables,
83+ ConversionPatternRewriter &rewriter, Location loc) {
8084 for (auto [value, var] : llvm::zip (values, variables))
8185 rewriter.create <emitc::AssignOp>(loc, var, value);
8286}
@@ -89,46 +93,58 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
8993 });
9094}
9195
92- static void lowerYield (SmallVector<Value> &resultVariables,
93- PatternRewriter &rewriter, scf::YieldOp yield) {
96+ static LogicalResult lowerYield (Operation *op, ValueRange resultVariables,
97+ ConversionPatternRewriter &rewriter,
98+ scf::YieldOp yield) {
9499 Location loc = yield.getLoc ();
95- ValueRange operands = yield.getOperands ();
96100
97101 OpBuilder::InsertionGuard guard (rewriter);
98102 rewriter.setInsertionPoint (yield);
99103
100- assignValues (operands, resultVariables, rewriter, loc);
104+ SmallVector<Value> yieldOperands;
105+ if (failed (rewriter.getRemappedValues (yield.getOperands (), yieldOperands))) {
106+ return rewriter.notifyMatchFailure (op, " failed to lower yield operands" );
107+ }
108+
109+ assignValues (yieldOperands, resultVariables, rewriter, loc);
101110
102111 rewriter.create <emitc::YieldOp>(loc);
103112 rewriter.eraseOp (yield);
113+
114+ return success ();
104115}
105116
106117// Lower the contents of an scf::if/scf::index_switch regions to an
107118// emitc::if/emitc::switch region. The contents of the lowering region is
108119// moved into the respective lowered region, but the scf::yield is replaced not
109120// only with an emitc::yield, but also with a sequence of emitc::assign ops that
110121// set the yielded values into the result variables.
111- static void lowerRegion (SmallVector<Value> & resultVariables,
112- PatternRewriter &rewriter, Region ®ion ,
113- Region &loweredRegion) {
122+ static LogicalResult lowerRegion (Operation *op, ValueRange resultVariables,
123+ ConversionPatternRewriter &rewriter ,
124+ Region ®ion, Region &loweredRegion) {
114125 rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
115126 Operation *terminator = loweredRegion.back ().getTerminator ();
116- lowerYield (resultVariables, rewriter, cast<scf::YieldOp>(terminator));
127+ return lowerYield (op, resultVariables, rewriter,
128+ cast<scf::YieldOp>(terminator));
117129}
118130
119- LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
120- PatternRewriter &rewriter) const {
131+ LogicalResult
132+ ForLowering::matchAndRewrite (ForOp forOp, OpAdaptor adaptor,
133+ ConversionPatternRewriter &rewriter) const {
121134 Location loc = forOp.getLoc ();
122135
123136 // Create an emitc::variable op for each result. These variables will be
124137 // assigned to by emitc::assign ops within the loop body.
125- SmallVector<Value> resultVariables =
126- createVariablesForResults (forOp, rewriter);
138+ SmallVector<Value> resultVariables;
139+ if (failed (createVariablesForResults (forOp, getTypeConverter (), rewriter,
140+ resultVariables)))
141+ return rewriter.notifyMatchFailure (forOp,
142+ " create variables for results failed" );
127143
128- assignValues (forOp. getInits (), resultVariables, rewriter, loc);
144+ assignValues (adaptor. getInitArgs (), resultVariables, rewriter, loc);
129145
130146 emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
131- loc, forOp .getLowerBound (), forOp .getUpperBound (), forOp .getStep ());
147+ loc, adaptor .getLowerBound (), adaptor .getUpperBound (), adaptor .getStep ());
132148
133149 Block *loweredBody = loweredFor.getBody ();
134150
@@ -143,13 +159,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
143159
144160 rewriter.restoreInsertionPoint (ip);
145161
162+ // Convert the original region types into the new types by adding unrealized
163+ // casts in the beginning of the loop. This performs the conversion in place.
164+ if (failed (rewriter.convertRegionTypes (&forOp.getRegion (),
165+ *getTypeConverter (), nullptr ))) {
166+ return rewriter.notifyMatchFailure (forOp, " region types conversion failed" );
167+ }
168+
169+ // Register the replacements for the block arguments and inline the body of
170+ // the scf.for loop into the body of the emitc::for loop.
171+ Block *scfBody = &(forOp.getRegion ().front ());
146172 SmallVector<Value> replacingValues;
147173 replacingValues.push_back (loweredFor.getInductionVar ());
148174 replacingValues.append (iterArgsValues.begin (), iterArgsValues.end ());
175+ rewriter.mergeBlocks (scfBody, loweredBody, replacingValues);
149176
150- rewriter.mergeBlocks (forOp.getBody (), loweredBody, replacingValues);
151- lowerYield (resultVariables, rewriter,
152- cast<scf::YieldOp>(loweredBody->getTerminator ()));
177+ auto result = lowerYield (forOp, resultVariables, rewriter,
178+ cast<scf::YieldOp>(loweredBody->getTerminator ()));
179+
180+ if (failed (result)) {
181+ return result;
182+ }
153183
154184 // Load variables into SSA values after the for loop.
155185 SmallVector<Value> resultValues = loadValues (resultVariables, rewriter, loc);
@@ -160,38 +190,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
160190
161191// Lower scf::if to emitc::if, implementing result values as emitc::variable's
162192// updated within the then and else regions.
163- struct IfLowering : public OpRewritePattern <IfOp> {
164- using OpRewritePattern <IfOp>::OpRewritePattern ;
193+ struct IfLowering : public OpConversionPattern <IfOp> {
194+ using OpConversionPattern <IfOp>::OpConversionPattern ;
165195
166- LogicalResult matchAndRewrite (IfOp ifOp,
167- PatternRewriter &rewriter) const override ;
196+ LogicalResult
197+ matchAndRewrite (IfOp ifOp, OpAdaptor adaptor,
198+ ConversionPatternRewriter &rewriter) const override ;
168199};
169200
170201} // namespace
171202
172- LogicalResult IfLowering::matchAndRewrite (IfOp ifOp,
173- PatternRewriter &rewriter) const {
203+ LogicalResult
204+ IfLowering::matchAndRewrite (IfOp ifOp, OpAdaptor adaptor,
205+ ConversionPatternRewriter &rewriter) const {
174206 Location loc = ifOp.getLoc ();
175207
176208 // Create an emitc::variable op for each result. These variables will be
177209 // assigned to by emitc::assign ops within the then & else regions.
178- SmallVector<Value> resultVariables =
179- createVariablesForResults (ifOp, rewriter);
180-
181- Region &thenRegion = ifOp.getThenRegion ();
182- Region &elseRegion = ifOp.getElseRegion ();
210+ SmallVector<Value> resultVariables;
211+ if (failed (createVariablesForResults (ifOp, getTypeConverter (), rewriter,
212+ resultVariables)))
213+ return rewriter.notifyMatchFailure (ifOp,
214+ " create variables for results failed" );
215+
216+ // Utility function to lower the contents of an scf::if region to an emitc::if
217+ // region. The contents of the scf::if regions is moved into the respective
218+ // emitc::if regions, but the scf::yield is replaced not only with an
219+ // emitc::yield, but also with a sequence of emitc::assign ops that set the
220+ // yielded values into the result variables.
221+ auto lowerRegion = [&resultVariables, &rewriter,
222+ &ifOp](Region ®ion, Region &loweredRegion) {
223+ rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
224+ Operation *terminator = loweredRegion.back ().getTerminator ();
225+ auto result = lowerYield (ifOp, resultVariables, rewriter,
226+ cast<scf::YieldOp>(terminator));
227+ if (failed (result)) {
228+ return result;
229+ }
230+ return success ();
231+ };
232+
233+ Region &thenRegion = adaptor.getThenRegion ();
234+ Region &elseRegion = adaptor.getElseRegion ();
183235
184236 bool hasElseBlock = !elseRegion.empty ();
185237
186238 auto loweredIf =
187- rewriter.create <emitc::IfOp>(loc, ifOp .getCondition (), false , false );
239+ rewriter.create <emitc::IfOp>(loc, adaptor .getCondition (), false , false );
188240
189241 Region &loweredThenRegion = loweredIf.getThenRegion ();
190- lowerRegion (resultVariables, rewriter, thenRegion, loweredThenRegion);
242+ auto result = lowerRegion (thenRegion, loweredThenRegion);
243+ if (failed (result)) {
244+ return result;
245+ }
191246
192247 if (hasElseBlock) {
193248 Region &loweredElseRegion = loweredIf.getElseRegion ();
194- lowerRegion (resultVariables, rewriter, elseRegion, loweredElseRegion);
249+ auto result = lowerRegion (elseRegion, loweredElseRegion);
250+ if (failed (result)) {
251+ return result;
252+ }
195253 }
196254
197255 rewriter.setInsertionPointAfter (ifOp);
@@ -203,37 +261,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
203261
204262// Lower scf::index_switch to emitc::switch, implementing result values as
205263// emitc::variable's updated within the case and default regions.
206- struct IndexSwitchOpLowering : public OpRewritePattern <IndexSwitchOp> {
207- using OpRewritePattern<IndexSwitchOp>::OpRewritePattern ;
264+ struct IndexSwitchOpLowering : public OpConversionPattern <IndexSwitchOp> {
265+ using OpConversionPattern::OpConversionPattern ;
208266
209- LogicalResult matchAndRewrite (IndexSwitchOp indexSwitchOp,
210- PatternRewriter &rewriter) const override ;
267+ LogicalResult
268+ matchAndRewrite (IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
269+ ConversionPatternRewriter &rewriter) const override ;
211270};
212271
213- LogicalResult
214- IndexSwitchOpLowering::matchAndRewrite ( IndexSwitchOp indexSwitchOp,
215- PatternRewriter &rewriter) const {
272+ LogicalResult IndexSwitchOpLowering::matchAndRewrite (
273+ IndexSwitchOp indexSwitchOp, OpAdaptor adaptor ,
274+ ConversionPatternRewriter &rewriter) const {
216275 Location loc = indexSwitchOp.getLoc ();
217276
218277 // Create an emitc::variable op for each result. These variables will be
219278 // assigned to by emitc::assign ops within the case and default regions.
220- SmallVector<Value> resultVariables =
221- createVariablesForResults (indexSwitchOp, rewriter);
279+ SmallVector<Value> resultVariables;
280+ if (failed (createVariablesForResults (indexSwitchOp, getTypeConverter (),
281+ rewriter, resultVariables))) {
282+ return rewriter.notifyMatchFailure (indexSwitchOp,
283+ " create variables for results failed" );
284+ }
222285
223286 auto loweredSwitch = rewriter.create <emitc::SwitchOp>(
224- loc, indexSwitchOp.getArg (), indexSwitchOp.getCases (),
225- indexSwitchOp.getNumCases ());
287+ loc, adaptor.getArg (), adaptor.getCases (), indexSwitchOp.getNumCases ());
226288
227289 // Lowering all case regions.
228- for (auto pair : llvm::zip (indexSwitchOp.getCaseRegions (),
229- loweredSwitch.getCaseRegions ())) {
230- lowerRegion (resultVariables, rewriter, std::get<0 >(pair),
231- std::get<1 >(pair));
290+ for (auto pair :
291+ llvm::zip (adaptor.getCaseRegions (), loweredSwitch.getCaseRegions ())) {
292+ if (failed (lowerRegion (indexSwitchOp, resultVariables, rewriter,
293+ *std::get<0 >(pair), std::get<1 >(pair)))) {
294+ return failure ();
295+ }
232296 }
233297
234298 // Lowering default region.
235- lowerRegion (resultVariables, rewriter, indexSwitchOp.getDefaultRegion (),
236- loweredSwitch.getDefaultRegion ());
299+ if (failed (lowerRegion (indexSwitchOp, resultVariables, rewriter,
300+ adaptor.getDefaultRegion (),
301+ loweredSwitch.getDefaultRegion ()))) {
302+ return failure ();
303+ }
237304
238305 rewriter.setInsertionPointAfter (indexSwitchOp);
239306 SmallVector<Value> results = loadValues (resultVariables, rewriter, loc);
@@ -242,15 +309,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
242309 return success ();
243310}
244311
245- void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns) {
246- patterns.add <ForLowering>(patterns.getContext ());
247- patterns.add <IfLowering>(patterns.getContext ());
248- patterns.add <IndexSwitchOpLowering>(patterns.getContext ());
312+ void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns,
313+ TypeConverter &typeConverter) {
314+ patterns.add <ForLowering>(typeConverter, patterns.getContext ());
315+ patterns.add <IfLowering>(typeConverter, patterns.getContext ());
316+ patterns.add <IndexSwitchOpLowering>(typeConverter, patterns.getContext ());
249317}
250318
251319void SCFToEmitCPass::runOnOperation () {
252320 RewritePatternSet patterns (&getContext ());
253- populateSCFToEmitCConversionPatterns (patterns);
321+ TypeConverter typeConverter;
322+ // Fallback converter
323+ // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
324+ // Type converters are called most to least recently inserted
325+ typeConverter.addConversion ([](Type t) { return t; });
326+ populateEmitCSizeTTypeConversions (typeConverter);
327+ populateSCFToEmitCConversionPatterns (patterns, typeConverter);
254328
255329 // Configure conversion to lower out SCF operations.
256330 ConversionTarget target (getContext ());
0 commit comments