1414
1515#include " mlir/Dialect/Arith/IR/Arith.h"
1616#include " mlir/Dialect/EmitC/IR/EmitC.h"
17- #include " mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1817#include " mlir/Dialect/SCF/IR/SCF.h"
1918#include " mlir/IR/Builders.h"
2019#include " mlir/IR/BuiltinOps.h"
@@ -40,22 +39,21 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
4039
4140// Lower scf::for to emitc::for, implementing result values using
4241// emitc::variable's updated within the loop body.
43- struct ForLowering : public OpConversionPattern <ForOp> {
44- using OpConversionPattern <ForOp>::OpConversionPattern ;
42+ struct ForLowering : public OpRewritePattern <ForOp> {
43+ using OpRewritePattern <ForOp>::OpRewritePattern ;
4544
46- LogicalResult
47- matchAndRewrite (ForOp forOp, OpAdaptor adaptor,
48- ConversionPatternRewriter &rewriter) const override ;
45+ LogicalResult matchAndRewrite (ForOp forOp,
46+ PatternRewriter &rewriter) const override ;
4947};
5048
5149// Create an uninitialized emitc::variable op for each result of the given op.
5250template <typename T>
53- static LogicalResult
54- createVariablesForResults (T op, const TypeConverter *typeConverter,
55- ConversionPatternRewriter &rewriter,
56- SmallVector<Value> &resultVariables) {
51+ static SmallVector<Value> createVariablesForResults (T op,
52+ PatternRewriter &rewriter) {
53+ SmallVector<Value> resultVariables;
54+
5755 if (!op.getNumResults ())
58- return success () ;
56+ return resultVariables ;
5957
6058 Location loc = op->getLoc ();
6159 MLIRContext *context = op.getContext ();
@@ -64,23 +62,21 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
6462 rewriter.setInsertionPoint (op);
6563
6664 for (OpResult result : op.getResults ()) {
67- Type resultType = typeConverter->convertType (result.getType ());
68- if (!resultType)
69- return rewriter.notifyMatchFailure (op, " result type conversion failed" );
65+ Type resultType = result.getType ();
7066 Type varType = emitc::LValueType::get (resultType);
7167 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
7268 emitc::VariableOp var =
7369 rewriter.create <emitc::VariableOp>(loc, varType, noInit);
7470 resultVariables.push_back (var);
7571 }
7672
77- return success () ;
73+ return resultVariables ;
7874}
7975
8076// Create a series of assign ops assigning given values to given variables at
8177// the current insertion point of given rewriter.
82- static void assignValues (ValueRange values, ValueRange variables,
83- ConversionPatternRewriter &rewriter, Location loc) {
78+ static void assignValues (ValueRange values, SmallVector<Value> & variables,
79+ PatternRewriter &rewriter, Location loc) {
8480 for (auto [value, var] : llvm::zip (values, variables))
8581 rewriter.create <emitc::AssignOp>(loc, var, value);
8682}
@@ -93,58 +89,46 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
9389 });
9490}
9591
96- static LogicalResult lowerYield (Operation *op, ValueRange resultVariables,
97- ConversionPatternRewriter &rewriter,
98- scf::YieldOp yield) {
92+ static void lowerYield (SmallVector<Value> &resultVariables,
93+ PatternRewriter &rewriter, scf::YieldOp yield) {
9994 Location loc = yield.getLoc ();
95+ ValueRange operands = yield.getOperands ();
10096
10197 OpBuilder::InsertionGuard guard (rewriter);
10298 rewriter.setInsertionPoint (yield);
10399
104- SmallVector<Value> yieldOperands;
105- if (failed (rewriter.getRemappedValues (yield.getOperands (), yieldOperands))) {
106- return rewriter.notifyMatchFailure (op, " failed to lower yield operands" );
107- }
108-
109- assignValues (yieldOperands, resultVariables, rewriter, loc);
100+ assignValues (operands, resultVariables, rewriter, loc);
110101
111102 rewriter.create <emitc::YieldOp>(loc);
112103 rewriter.eraseOp (yield);
113-
114- return success ();
115104}
116105
117106// Lower the contents of an scf::if/scf::index_switch regions to an
118107// emitc::if/emitc::switch region. The contents of the lowering region is
119108// moved into the respective lowered region, but the scf::yield is replaced not
120109// only with an emitc::yield, but also with a sequence of emitc::assign ops that
121110// set the yielded values into the result variables.
122- static LogicalResult lowerRegion (Operation *op, ValueRange resultVariables,
123- ConversionPatternRewriter &rewriter ,
124- Region ®ion, Region &loweredRegion) {
111+ static void lowerRegion (SmallVector<Value> & resultVariables,
112+ PatternRewriter &rewriter, Region ®ion ,
113+ Region &loweredRegion) {
125114 rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
126115 Operation *terminator = loweredRegion.back ().getTerminator ();
127- return lowerYield (op, resultVariables, rewriter,
128- cast<scf::YieldOp>(terminator));
116+ lowerYield (resultVariables, rewriter, cast<scf::YieldOp>(terminator));
129117}
130118
131- LogicalResult
132- ForLowering::matchAndRewrite (ForOp forOp, OpAdaptor adaptor,
133- ConversionPatternRewriter &rewriter) const {
119+ LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
120+ PatternRewriter &rewriter) const {
134121 Location loc = forOp.getLoc ();
135122
136123 // Create an emitc::variable op for each result. These variables will be
137124 // assigned to by emitc::assign ops within the loop body.
138- SmallVector<Value> resultVariables;
139- if (failed (createVariablesForResults (forOp, getTypeConverter (), rewriter,
140- resultVariables)))
141- return rewriter.notifyMatchFailure (forOp,
142- " create variables for results failed" );
125+ SmallVector<Value> resultVariables =
126+ createVariablesForResults (forOp, rewriter);
143127
144- assignValues (adaptor. getInitArgs (), resultVariables, rewriter, loc);
128+ assignValues (forOp. getInits (), resultVariables, rewriter, loc);
145129
146130 emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
147- loc, adaptor .getLowerBound (), adaptor .getUpperBound (), adaptor .getStep ());
131+ loc, forOp .getLowerBound (), forOp .getUpperBound (), forOp .getStep ());
148132
149133 Block *loweredBody = loweredFor.getBody ();
150134
@@ -159,27 +143,13 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
159143
160144 rewriter.restoreInsertionPoint (ip);
161145
162- // Convert the original region types into the new types by adding unrealized
163- // casts in the beginning of the loop. This performs the conversion in place.
164- if (failed (rewriter.convertRegionTypes (&forOp.getRegion (),
165- *getTypeConverter (), nullptr ))) {
166- return rewriter.notifyMatchFailure (forOp, " region types conversion failed" );
167- }
168-
169- // Register the replacements for the block arguments and inline the body of
170- // the scf.for loop into the body of the emitc::for loop.
171- Block *scfBody = &(forOp.getRegion ().front ());
172146 SmallVector<Value> replacingValues;
173147 replacingValues.push_back (loweredFor.getInductionVar ());
174148 replacingValues.append (iterArgsValues.begin (), iterArgsValues.end ());
175- rewriter.mergeBlocks (scfBody, loweredBody, replacingValues);
176149
177- auto result = lowerYield (forOp, resultVariables, rewriter,
178- cast<scf::YieldOp>(loweredBody->getTerminator ()));
179-
180- if (failed (result)) {
181- return result;
182- }
150+ rewriter.mergeBlocks (forOp.getBody (), loweredBody, replacingValues);
151+ lowerYield (resultVariables, rewriter,
152+ cast<scf::YieldOp>(loweredBody->getTerminator ()));
183153
184154 // Load variables into SSA values after the for loop.
185155 SmallVector<Value> resultValues = loadValues (resultVariables, rewriter, loc);
@@ -190,66 +160,38 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
190160
191161// Lower scf::if to emitc::if, implementing result values as emitc::variable's
192162// updated within the then and else regions.
193- struct IfLowering : public OpConversionPattern <IfOp> {
194- using OpConversionPattern <IfOp>::OpConversionPattern ;
163+ struct IfLowering : public OpRewritePattern <IfOp> {
164+ using OpRewritePattern <IfOp>::OpRewritePattern ;
195165
196- LogicalResult
197- matchAndRewrite (IfOp ifOp, OpAdaptor adaptor,
198- ConversionPatternRewriter &rewriter) const override ;
166+ LogicalResult matchAndRewrite (IfOp ifOp,
167+ PatternRewriter &rewriter) const override ;
199168};
200169
201170} // namespace
202171
203- LogicalResult
204- IfLowering::matchAndRewrite (IfOp ifOp, OpAdaptor adaptor,
205- ConversionPatternRewriter &rewriter) const {
172+ LogicalResult IfLowering::matchAndRewrite (IfOp ifOp,
173+ PatternRewriter &rewriter) const {
206174 Location loc = ifOp.getLoc ();
207175
208176 // Create an emitc::variable op for each result. These variables will be
209177 // assigned to by emitc::assign ops within the then & else regions.
210- SmallVector<Value> resultVariables;
211- if (failed (createVariablesForResults (ifOp, getTypeConverter (), rewriter,
212- resultVariables)))
213- return rewriter.notifyMatchFailure (ifOp,
214- " create variables for results failed" );
215-
216- // Utility function to lower the contents of an scf::if region to an emitc::if
217- // region. The contents of the scf::if regions is moved into the respective
218- // emitc::if regions, but the scf::yield is replaced not only with an
219- // emitc::yield, but also with a sequence of emitc::assign ops that set the
220- // yielded values into the result variables.
221- auto lowerRegion = [&resultVariables, &rewriter,
222- &ifOp](Region ®ion, Region &loweredRegion) {
223- rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
224- Operation *terminator = loweredRegion.back ().getTerminator ();
225- auto result = lowerYield (ifOp, resultVariables, rewriter,
226- cast<scf::YieldOp>(terminator));
227- if (failed (result)) {
228- return result;
229- }
230- return success ();
231- };
232-
233- Region &thenRegion = adaptor.getThenRegion ();
234- Region &elseRegion = adaptor.getElseRegion ();
178+ SmallVector<Value> resultVariables =
179+ createVariablesForResults (ifOp, rewriter);
180+
181+ Region &thenRegion = ifOp.getThenRegion ();
182+ Region &elseRegion = ifOp.getElseRegion ();
235183
236184 bool hasElseBlock = !elseRegion.empty ();
237185
238186 auto loweredIf =
239- rewriter.create <emitc::IfOp>(loc, adaptor .getCondition (), false , false );
187+ rewriter.create <emitc::IfOp>(loc, ifOp .getCondition (), false , false );
240188
241189 Region &loweredThenRegion = loweredIf.getThenRegion ();
242- auto result = lowerRegion (thenRegion, loweredThenRegion);
243- if (failed (result)) {
244- return result;
245- }
190+ lowerRegion (resultVariables, rewriter, thenRegion, loweredThenRegion);
246191
247192 if (hasElseBlock) {
248193 Region &loweredElseRegion = loweredIf.getElseRegion ();
249- auto result = lowerRegion (elseRegion, loweredElseRegion);
250- if (failed (result)) {
251- return result;
252- }
194+ lowerRegion (resultVariables, rewriter, elseRegion, loweredElseRegion);
253195 }
254196
255197 rewriter.setInsertionPointAfter (ifOp);
@@ -261,46 +203,37 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
261203
262204// Lower scf::index_switch to emitc::switch, implementing result values as
263205// emitc::variable's updated within the case and default regions.
264- struct IndexSwitchOpLowering : public OpConversionPattern <IndexSwitchOp> {
265- using OpConversionPattern::OpConversionPattern ;
206+ struct IndexSwitchOpLowering : public OpRewritePattern <IndexSwitchOp> {
207+ using OpRewritePattern<IndexSwitchOp>::OpRewritePattern ;
266208
267- LogicalResult
268- matchAndRewrite (IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
269- ConversionPatternRewriter &rewriter) const override ;
209+ LogicalResult matchAndRewrite (IndexSwitchOp indexSwitchOp,
210+ PatternRewriter &rewriter) const override ;
270211};
271212
272- LogicalResult IndexSwitchOpLowering::matchAndRewrite (
273- IndexSwitchOp indexSwitchOp, OpAdaptor adaptor ,
274- ConversionPatternRewriter &rewriter) const {
213+ LogicalResult
214+ IndexSwitchOpLowering::matchAndRewrite ( IndexSwitchOp indexSwitchOp,
215+ PatternRewriter &rewriter) const {
275216 Location loc = indexSwitchOp.getLoc ();
276217
277218 // Create an emitc::variable op for each result. These variables will be
278219 // assigned to by emitc::assign ops within the case and default regions.
279- SmallVector<Value> resultVariables;
280- if (failed (createVariablesForResults (indexSwitchOp, getTypeConverter (),
281- rewriter, resultVariables))) {
282- return rewriter.notifyMatchFailure (indexSwitchOp,
283- " create variables for results failed" );
284- }
220+ SmallVector<Value> resultVariables =
221+ createVariablesForResults (indexSwitchOp, rewriter);
285222
286223 auto loweredSwitch = rewriter.create <emitc::SwitchOp>(
287- loc, adaptor.getArg (), adaptor.getCases (), indexSwitchOp.getNumCases ());
224+ loc, indexSwitchOp.getArg (), indexSwitchOp.getCases (),
225+ indexSwitchOp.getNumCases ());
288226
289227 // Lowering all case regions.
290- for (auto pair :
291- llvm::zip (adaptor.getCaseRegions (), loweredSwitch.getCaseRegions ())) {
292- if (failed (lowerRegion (indexSwitchOp, resultVariables, rewriter,
293- *std::get<0 >(pair), std::get<1 >(pair)))) {
294- return failure ();
295- }
228+ for (auto pair : llvm::zip (indexSwitchOp.getCaseRegions (),
229+ loweredSwitch.getCaseRegions ())) {
230+ lowerRegion (resultVariables, rewriter, std::get<0 >(pair),
231+ std::get<1 >(pair));
296232 }
297233
298234 // Lowering default region.
299- if (failed (lowerRegion (indexSwitchOp, resultVariables, rewriter,
300- adaptor.getDefaultRegion (),
301- loweredSwitch.getDefaultRegion ()))) {
302- return failure ();
303- }
235+ lowerRegion (resultVariables, rewriter, indexSwitchOp.getDefaultRegion (),
236+ loweredSwitch.getDefaultRegion ());
304237
305238 rewriter.setInsertionPointAfter (indexSwitchOp);
306239 SmallVector<Value> results = loadValues (resultVariables, rewriter, loc);
@@ -309,22 +242,15 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
309242 return success ();
310243}
311244
312- void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns,
313- TypeConverter &typeConverter) {
314- patterns.add <ForLowering>(typeConverter, patterns.getContext ());
315- patterns.add <IfLowering>(typeConverter, patterns.getContext ());
316- patterns.add <IndexSwitchOpLowering>(typeConverter, patterns.getContext ());
245+ void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns) {
246+ patterns.add <ForLowering>(patterns.getContext ());
247+ patterns.add <IfLowering>(patterns.getContext ());
248+ patterns.add <IndexSwitchOpLowering>(patterns.getContext ());
317249}
318250
319251void SCFToEmitCPass::runOnOperation () {
320252 RewritePatternSet patterns (&getContext ());
321- TypeConverter typeConverter;
322- // Fallback converter
323- // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
324- // Type converters are called most to least recently inserted
325- typeConverter.addConversion ([](Type t) { return t; });
326- populateEmitCSizeTTypeConversions (typeConverter);
327- populateSCFToEmitCConversionPatterns (patterns, typeConverter);
253+ populateSCFToEmitCConversionPatterns (patterns);
328254
329255 // Configure conversion to lower out SCF operations.
330256 ConversionTarget target (getContext ());
0 commit comments