Skip to content

Commit cd10aea

Browse files
committed
SCFToEmitC: Convert types while converting from SCF to EmitC
Switch from rewrite patterns to conversion patterns. This allows to perform type conversions together with other parts of the IR. For example, this allows to convert from index to emit.size_t types
1 parent e9be217 commit cd10aea

File tree

4 files changed

+229
-80
lines changed

4 files changed

+229
-80
lines changed

mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
1010
#define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
1111

12+
#include "mlir/Transforms/DialectConversion.h"
1213
#include <memory>
1314

1415
namespace mlir {
@@ -19,7 +20,8 @@ class RewritePatternSet;
1920
#include "mlir/Conversion/Passes.h.inc"
2021

2122
/// Collect a set of patterns to convert SCF operations to the EmitC dialect.
22-
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns);
23+
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
24+
TypeConverter &typeConverter);
2325
} // namespace mlir
2426

2527
#endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 141 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17+
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1718
#include "mlir/Dialect/SCF/IR/SCF.h"
1819
#include "mlir/IR/Builders.h"
1920
#include "mlir/IR/BuiltinOps.h"
2021
#include "mlir/IR/IRMapping.h"
2122
#include "mlir/IR/MLIRContext.h"
2223
#include "mlir/IR/PatternMatch.h"
2324
#include "mlir/Transforms/DialectConversion.h"
25+
#include "mlir/Transforms/OneToNTypeConversion.h"
2426
#include "mlir/Transforms/Passes.h"
2527

2628
namespace mlir {
@@ -39,21 +41,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
3941

4042
// Lower scf::for to emitc::for, implementing result values using
4143
// emitc::variable's updated within the loop body.
42-
struct ForLowering : public OpRewritePattern<ForOp> {
43-
using OpRewritePattern<ForOp>::OpRewritePattern;
44+
struct ForLowering : public OpConversionPattern<ForOp> {
45+
using OpConversionPattern<ForOp>::OpConversionPattern;
4446

45-
LogicalResult matchAndRewrite(ForOp forOp,
46-
PatternRewriter &rewriter) const override;
47+
LogicalResult
48+
matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
49+
ConversionPatternRewriter &rewriter) const override;
4750
};
4851

4952
// Create an uninitialized emitc::variable op for each result of the given op.
5053
template <typename T>
51-
static SmallVector<Value> createVariablesForResults(T op,
52-
PatternRewriter &rewriter) {
53-
SmallVector<Value> resultVariables;
54-
54+
static LogicalResult
55+
createVariablesForResults(T op, const TypeConverter *typeConverter,
56+
ConversionPatternRewriter &rewriter,
57+
SmallVector<Value> &resultVariables) {
5558
if (!op.getNumResults())
56-
return resultVariables;
59+
return success();
5760

5861
Location loc = op->getLoc();
5962
MLIRContext *context = op.getContext();
@@ -62,21 +65,23 @@ static SmallVector<Value> createVariablesForResults(T op,
6265
rewriter.setInsertionPoint(op);
6366

6467
for (OpResult result : op.getResults()) {
65-
Type resultType = result.getType();
68+
Type resultType = typeConverter->convertType(result.getType());
69+
if (!resultType)
70+
return rewriter.notifyMatchFailure(op, "result type conversion failed");
6671
Type varType = emitc::LValueType::get(resultType);
6772
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
6873
emitc::VariableOp var =
6974
rewriter.create<emitc::VariableOp>(loc, varType, noInit);
7075
resultVariables.push_back(var);
7176
}
7277

73-
return resultVariables;
78+
return success();
7479
}
7580

7681
// Create a series of assign ops assigning given values to given variables at
7782
// the current insertion point of given rewriter.
78-
static void assignValues(ValueRange values, SmallVector<Value> &variables,
79-
PatternRewriter &rewriter, Location loc) {
83+
static void assignValues(ValueRange values, ValueRange variables,
84+
ConversionPatternRewriter &rewriter, Location loc) {
8085
for (auto [value, var] : llvm::zip(values, variables))
8186
rewriter.create<emitc::AssignOp>(loc, var, value);
8287
}
@@ -89,46 +94,58 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
8994
});
9095
}
9196

92-
static void lowerYield(SmallVector<Value> &resultVariables,
93-
PatternRewriter &rewriter, scf::YieldOp yield) {
97+
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
98+
ConversionPatternRewriter &rewriter,
99+
scf::YieldOp yield) {
94100
Location loc = yield.getLoc();
95-
ValueRange operands = yield.getOperands();
96101

97102
OpBuilder::InsertionGuard guard(rewriter);
98103
rewriter.setInsertionPoint(yield);
99104

100-
assignValues(operands, resultVariables, rewriter, loc);
105+
SmallVector<Value> yieldOperands;
106+
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
107+
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
108+
}
109+
110+
assignValues(yieldOperands, resultVariables, rewriter, loc);
101111

102112
rewriter.create<emitc::YieldOp>(loc);
103113
rewriter.eraseOp(yield);
114+
115+
return success();
104116
}
105117

106118
// Lower the contents of an scf::if/scf::index_switch regions to an
107119
// emitc::if/emitc::switch region. The contents of the lowering region is
108120
// moved into the respective lowered region, but the scf::yield is replaced not
109121
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
110122
// set the yielded values into the result variables.
111-
static void lowerRegion(SmallVector<Value> &resultVariables,
112-
PatternRewriter &rewriter, Region &region,
113-
Region &loweredRegion) {
123+
static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
124+
ConversionPatternRewriter &rewriter,
125+
Region &region, Region &loweredRegion) {
114126
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
115127
Operation *terminator = loweredRegion.back().getTerminator();
116-
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
128+
return lowerYield(op, resultVariables, rewriter,
129+
cast<scf::YieldOp>(terminator));
117130
}
118131

119-
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
120-
PatternRewriter &rewriter) const {
132+
LogicalResult
133+
ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
134+
ConversionPatternRewriter &rewriter) const {
121135
Location loc = forOp.getLoc();
122136

123137
// Create an emitc::variable op for each result. These variables will be
124138
// assigned to by emitc::assign ops within the loop body.
125-
SmallVector<Value> resultVariables =
126-
createVariablesForResults(forOp, rewriter);
139+
SmallVector<Value> resultVariables;
140+
if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
141+
resultVariables)))
142+
return rewriter.notifyMatchFailure(forOp,
143+
"create variables for results failed");
127144

128-
assignValues(forOp.getInits(), resultVariables, rewriter, loc);
145+
assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
129146

130147
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
131-
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
148+
loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
132149

133150
Block *loweredBody = loweredFor.getBody();
134151

@@ -143,13 +160,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
143160

144161
rewriter.restoreInsertionPoint(ip);
145162

163+
// Convert the original region types into the new types by adding unrealized
164+
// casts in the beginning of the loop. This performs the conversion in place.
165+
if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
166+
*getTypeConverter(), nullptr))) {
167+
return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
168+
}
169+
170+
// Register the replacements for the block arguments and inline the body of
171+
// the scf.for loop into the body of the emitc::for loop.
172+
Block *scfBody = &(forOp.getRegion().front());
146173
SmallVector<Value> replacingValues;
147174
replacingValues.push_back(loweredFor.getInductionVar());
148175
replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
176+
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
149177

150-
rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
151-
lowerYield(resultVariables, rewriter,
152-
cast<scf::YieldOp>(loweredBody->getTerminator()));
178+
auto result = lowerYield(forOp, resultVariables, rewriter,
179+
cast<scf::YieldOp>(loweredBody->getTerminator()));
180+
181+
if (failed(result)) {
182+
return result;
183+
}
153184

154185
// Load variables into SSA values after the for loop.
155186
SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
@@ -160,38 +191,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
160191

161192
// Lower scf::if to emitc::if, implementing result values as emitc::variable's
162193
// updated within the then and else regions.
163-
struct IfLowering : public OpRewritePattern<IfOp> {
164-
using OpRewritePattern<IfOp>::OpRewritePattern;
194+
struct IfLowering : public OpConversionPattern<IfOp> {
195+
using OpConversionPattern<IfOp>::OpConversionPattern;
165196

166-
LogicalResult matchAndRewrite(IfOp ifOp,
167-
PatternRewriter &rewriter) const override;
197+
LogicalResult
198+
matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
199+
ConversionPatternRewriter &rewriter) const override;
168200
};
169201

170202
} // namespace
171203

172-
LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
173-
PatternRewriter &rewriter) const {
204+
LogicalResult
205+
IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
206+
ConversionPatternRewriter &rewriter) const {
174207
Location loc = ifOp.getLoc();
175208

176209
// Create an emitc::variable op for each result. These variables will be
177210
// assigned to by emitc::assign ops within the then & else regions.
178-
SmallVector<Value> resultVariables =
179-
createVariablesForResults(ifOp, rewriter);
180-
181-
Region &thenRegion = ifOp.getThenRegion();
182-
Region &elseRegion = ifOp.getElseRegion();
211+
SmallVector<Value> resultVariables;
212+
if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
213+
resultVariables)))
214+
return rewriter.notifyMatchFailure(ifOp,
215+
"create variables for results failed");
216+
217+
// Utility function to lower the contents of an scf::if region to an emitc::if
218+
// region. The contents of the scf::if regions is moved into the respective
219+
// emitc::if regions, but the scf::yield is replaced not only with an
220+
// emitc::yield, but also with a sequence of emitc::assign ops that set the
221+
// yielded values into the result variables.
222+
auto lowerRegion = [&resultVariables, &rewriter,
223+
&ifOp](Region &region, Region &loweredRegion) {
224+
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
225+
Operation *terminator = loweredRegion.back().getTerminator();
226+
auto result = lowerYield(ifOp, resultVariables, rewriter,
227+
cast<scf::YieldOp>(terminator));
228+
if (failed(result)) {
229+
return result;
230+
}
231+
return success();
232+
};
233+
234+
Region &thenRegion = adaptor.getThenRegion();
235+
Region &elseRegion = adaptor.getElseRegion();
183236

184237
bool hasElseBlock = !elseRegion.empty();
185238

186239
auto loweredIf =
187-
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
240+
rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
188241

189242
Region &loweredThenRegion = loweredIf.getThenRegion();
190-
lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);
243+
auto result = lowerRegion(thenRegion, loweredThenRegion);
244+
if (failed(result)) {
245+
return result;
246+
}
191247

192248
if (hasElseBlock) {
193249
Region &loweredElseRegion = loweredIf.getElseRegion();
194-
lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
250+
auto result = lowerRegion(elseRegion, loweredElseRegion);
251+
if (failed(result)) {
252+
return result;
253+
}
195254
}
196255

197256
rewriter.setInsertionPointAfter(ifOp);
@@ -203,37 +262,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
203262

204263
// Lower scf::index_switch to emitc::switch, implementing result values as
205264
// emitc::variable's updated within the case and default regions.
206-
struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
207-
using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
265+
struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
266+
using OpConversionPattern::OpConversionPattern;
208267

209-
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
210-
PatternRewriter &rewriter) const override;
268+
LogicalResult
269+
matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
270+
ConversionPatternRewriter &rewriter) const override;
211271
};
212272

213-
LogicalResult
214-
IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
215-
PatternRewriter &rewriter) const {
273+
LogicalResult IndexSwitchOpLowering::matchAndRewrite(
274+
IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
275+
ConversionPatternRewriter &rewriter) const {
216276
Location loc = indexSwitchOp.getLoc();
217277

218278
// Create an emitc::variable op for each result. These variables will be
219279
// assigned to by emitc::assign ops within the case and default regions.
220-
SmallVector<Value> resultVariables =
221-
createVariablesForResults(indexSwitchOp, rewriter);
280+
SmallVector<Value> resultVariables;
281+
if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
282+
rewriter, resultVariables))) {
283+
return rewriter.notifyMatchFailure(indexSwitchOp,
284+
"create variables for results failed");
285+
}
222286

223287
auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
224-
loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
225-
indexSwitchOp.getNumCases());
288+
loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
226289

227290
// Lowering all case regions.
228-
for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
229-
loweredSwitch.getCaseRegions())) {
230-
lowerRegion(resultVariables, rewriter, std::get<0>(pair),
231-
std::get<1>(pair));
291+
for (auto pair :
292+
llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
293+
if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
294+
*std::get<0>(pair), std::get<1>(pair)))) {
295+
return failure();
296+
}
232297
}
233298

234299
// Lowering default region.
235-
lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
236-
loweredSwitch.getDefaultRegion());
300+
if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
301+
adaptor.getDefaultRegion(),
302+
loweredSwitch.getDefaultRegion()))) {
303+
return failure();
304+
}
237305

238306
rewriter.setInsertionPointAfter(indexSwitchOp);
239307
SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
@@ -242,15 +310,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
242310
return success();
243311
}
244312

245-
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
246-
patterns.add<ForLowering>(patterns.getContext());
247-
patterns.add<IfLowering>(patterns.getContext());
248-
patterns.add<IndexSwitchOpLowering>(patterns.getContext());
313+
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
314+
TypeConverter &typeConverter) {
315+
patterns.add<ForLowering>(typeConverter, patterns.getContext());
316+
patterns.add<IfLowering>(typeConverter, patterns.getContext());
317+
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
249318
}
250319

251320
void SCFToEmitCPass::runOnOperation() {
252321
RewritePatternSet patterns(&getContext());
253-
populateSCFToEmitCConversionPatterns(patterns);
322+
TypeConverter typeConverter;
323+
// Fallback converter
324+
// See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
325+
// Type converters are called most to least recently inserted
326+
typeConverter.addConversion([](Type t) { return t; });
327+
populateEmitCSizeTTypeConversions(typeConverter);
328+
populateSCFToEmitCConversionPatterns(patterns, typeConverter);
254329

255330
// Configure conversion to lower out SCF operations.
256331
ConversionTarget target(getContext());

0 commit comments

Comments
 (0)