Skip to content

Commit f4c05be

Browse files
[mlir][toy] Update dialect conversion example (#150826)
The Toy tutorial used outdated API. Update the example to: * Use the `OpAdaptor` in all places. * Do not mix `RewritePattern` and `ConversionPattern`. This cannot always be done safely and should not be advertised in the example code.
1 parent 0f2484a commit f4c05be

File tree

6 files changed

+193
-262
lines changed

6 files changed

+193
-262
lines changed

mlir/docs/Tutorials/Toy/Ch-5.md

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,11 @@ doesn't matter. See `ConversionTarget::getOpInfo` for the details.
9191
After the conversion target has been defined, we can define how to convert the
9292
*illegal* operations into *legal* ones. Similarly to the canonicalization
9393
framework introduced in [chapter 3](Ch-3.md), the
94-
[`DialectConversion` framework](../../DialectConversion.md) also uses
95-
[RewritePatterns](../QuickstartRewrites.md) to perform the conversion logic.
96-
These patterns may be the `RewritePatterns` seen before or a new type of pattern
97-
specific to the conversion framework `ConversionPattern`. `ConversionPatterns`
94+
[`DialectConversion` framework](../../DialectConversion.md) uses a special kind
95+
of `ConversionPattern` to perform the conversion logic. `ConversionPatterns`
9896
are different from traditional `RewritePatterns` in that they accept an
99-
additional `operands` parameter containing operands that have been
100-
remapped/replaced. This is used when dealing with type conversions, as the
97+
additional `operands` (or `adaptor`) parameter containing operands that have
98+
been remapped/replaced. This is used when dealing with type conversions, as the
10199
pattern will want to operate on values of the new type but match against the
102100
old. For our lowering, this invariant will be useful as it translates from the
103101
[TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently being
@@ -106,38 +104,23 @@ look at a snippet of lowering the `toy.transpose` operation:
106104

107105
```c++
108106
/// Lower the `toy.transpose` operation to an affine loop nest.
109-
struct TransposeOpLowering : public mlir::ConversionPattern {
110-
TransposeOpLowering(mlir::MLIRContext *ctx)
111-
: mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}
112-
113-
/// Match and rewrite the given `toy.transpose` operation, with the given
114-
/// operands that have been remapped from `tensor<...>` to `memref<...>`.
115-
llvm::LogicalResult
116-
matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
117-
mlir::ConversionPatternRewriter &rewriter) const final {
118-
auto loc = op->getLoc();
107+
struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
108+
using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
119109

120-
// Call to a helper function that will lower the current operation to a set
121-
// of affine loops. We provide a functor that operates on the remapped
122-
// operands, as well as the loop induction variables for the inner most
123-
// loop body.
124-
lowerOpToLoops(
125-
op, operands, rewriter,
126-
[loc](mlir::PatternRewriter &rewriter,
127-
ArrayRef<mlir::Value> memRefOperands,
128-
ArrayRef<mlir::Value> loopIvs) {
129-
// Generate an adaptor for the remapped operands of the TransposeOp.
130-
// This allows for using the nice named accessors that are generated
131-
// by the ODS. This adaptor is automatically provided by the ODS
132-
// framework.
133-
TransposeOpAdaptor transposeAdaptor(memRefOperands);
134-
mlir::Value input = transposeAdaptor.input();
135-
136-
// Transpose the elements by generating a load from the reverse
137-
// indices.
138-
SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
139-
return mlir::AffineLoadOp::create(rewriter, loc, input, reverseIvs);
140-
});
110+
LogicalResult
111+
matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
112+
ConversionPatternRewriter &rewriter) const final {
113+
auto loc = op->getLoc();
114+
lowerOpToLoops(op, rewriter,
115+
[&](OpBuilder &builder, ValueRange loopIvs) {
116+
Value input = adaptor.getInput();
117+
118+
// Transpose the elements by generating a load from the
119+
// reverse indices.
120+
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
121+
return affine::AffineLoadOp::create(builder, loc, input,
122+
reverseIvs);
123+
});
141124
return success();
142125
}
143126
};

mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp

Lines changed: 55 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
using namespace mlir;
4545

4646
//===----------------------------------------------------------------------===//
47-
// ToyToAffine RewritePatterns
47+
// ToyToAffine Conversion Patterns
4848
//===----------------------------------------------------------------------===//
4949

5050
/// Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6969
}
7070

7171
/// This defines the function type used to process an iteration of a lowered
72-
/// loop. It takes as input an OpBuilder, an range of memRefOperands
73-
/// corresponding to the operands of the input operation, and the range of loop
74-
/// induction variables for the iteration. It returns a value to store at the
75-
/// current index of the iteration.
76-
using LoopIterationFn = function_ref<Value(
77-
OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
78-
79-
static void lowerOpToLoops(Operation *op, ValueRange operands,
80-
PatternRewriter &rewriter,
72+
/// loop. It takes as input an OpBuilder and the range of loop induction
73+
/// variables for the iteration. It returns a value to store at the current
74+
/// index of the iteration.
75+
using LoopIterationFn =
76+
function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
77+
78+
static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter,
8179
LoopIterationFn processIteration) {
8280
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
8381
auto loc = op->getLoc();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
9593
affine::buildAffineLoopNest(
9694
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
9795
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
98-
// Call the processing function with the rewriter, the memref operands,
99-
// and the loop induction variables. This function will return the value
100-
// to store at the current index.
101-
Value valueToStore = processIteration(nestedBuilder, operands, ivs);
96+
// Call the processing function with the rewriter and the loop
97+
// induction variables. This function will return the value to store at
98+
// the current index.
99+
Value valueToStore = processIteration(nestedBuilder, ivs);
102100
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
103101
ivs);
104102
});
@@ -109,53 +107,46 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
109107

110108
namespace {
111109
//===----------------------------------------------------------------------===//
112-
// ToyToAffine RewritePatterns: Binary operations
110+
// ToyToAffine Conversion Patterns: Binary operations
113111
//===----------------------------------------------------------------------===//
114112

115113
template <typename BinaryOp, typename LoweredBinaryOp>
116-
struct BinaryOpLowering : public ConversionPattern {
117-
BinaryOpLowering(MLIRContext *ctx)
118-
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
114+
struct BinaryOpLowering : public OpConversionPattern<BinaryOp> {
115+
using OpConversionPattern<BinaryOp>::OpConversionPattern;
116+
using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
119117

120118
LogicalResult
121-
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
119+
matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
122120
ConversionPatternRewriter &rewriter) const final {
123121
auto loc = op->getLoc();
124-
lowerOpToLoops(op, operands, rewriter,
125-
[loc](OpBuilder &builder, ValueRange memRefOperands,
126-
ValueRange loopIvs) {
127-
// Generate an adaptor for the remapped operands of the
128-
// BinaryOp. This allows for using the nice named accessors
129-
// that are generated by the ODS.
130-
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
131-
132-
// Generate loads for the element of 'lhs' and 'rhs' at the
133-
// inner loop.
134-
auto loadedLhs = affine::AffineLoadOp::create(
135-
builder, loc, binaryAdaptor.getLhs(), loopIvs);
136-
auto loadedRhs = affine::AffineLoadOp::create(
137-
builder, loc, binaryAdaptor.getRhs(), loopIvs);
138-
139-
// Create the binary operation performed on the loaded
140-
// values.
141-
return LoweredBinaryOp::create(builder, loc, loadedLhs,
142-
loadedRhs);
143-
});
122+
lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
123+
// Generate loads for the element of 'lhs' and 'rhs' at the
124+
// inner loop.
125+
auto loadedLhs =
126+
affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs);
127+
auto loadedRhs =
128+
affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs);
129+
130+
// Create the binary operation performed on the loaded
131+
// values.
132+
return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs);
133+
});
144134
return success();
145135
}
146136
};
147137
using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
148138
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
149139

150140
//===----------------------------------------------------------------------===//
151-
// ToyToAffine RewritePatterns: Constant operations
141+
// ToyToAffine Conversion Patterns: Constant operations
152142
//===----------------------------------------------------------------------===//
153143

154-
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
155-
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
144+
struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> {
145+
using OpConversionPattern<toy::ConstantOp>::OpConversionPattern;
156146

157-
LogicalResult matchAndRewrite(toy::ConstantOp op,
158-
PatternRewriter &rewriter) const final {
147+
LogicalResult
148+
matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor,
149+
ConversionPatternRewriter &rewriter) const final {
159150
DenseElementsAttr constantValue = op.getValue();
160151
Location loc = op.getLoc();
161152

@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
216207
};
217208

218209
//===----------------------------------------------------------------------===//
219-
// ToyToAffine RewritePatterns: Func operations
210+
// ToyToAffine Conversion Patterns: Func operations
220211
//===----------------------------------------------------------------------===//
221212

222213
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
247238
};
248239

249240
//===----------------------------------------------------------------------===//
250-
// ToyToAffine RewritePatterns: Print operations
241+
// ToyToAffine Conversion Patterns: Print operations
251242
//===----------------------------------------------------------------------===//
252243

253244
struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
265256
};
266257

267258
//===----------------------------------------------------------------------===//
268-
// ToyToAffine RewritePatterns: Return operations
259+
// ToyToAffine Conversion Patterns: Return operations
269260
//===----------------------------------------------------------------------===//
270261

271-
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
272-
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
262+
struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> {
263+
using OpConversionPattern<toy::ReturnOp>::OpConversionPattern;
273264

274-
LogicalResult matchAndRewrite(toy::ReturnOp op,
275-
PatternRewriter &rewriter) const final {
265+
LogicalResult
266+
matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor,
267+
ConversionPatternRewriter &rewriter) const final {
276268
// During this lowering, we expect that all function calls have been
277269
// inlined.
278270
if (op.hasOperand())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
285277
};
286278

287279
//===----------------------------------------------------------------------===//
288-
// ToyToAffine RewritePatterns: Transpose operations
280+
// ToyToAffine Conversion Patterns: Transpose operations
289281
//===----------------------------------------------------------------------===//
290282

291-
struct TransposeOpLowering : public ConversionPattern {
292-
TransposeOpLowering(MLIRContext *ctx)
293-
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
283+
struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
284+
using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
294285

295286
LogicalResult
296-
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
287+
matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
297288
ConversionPatternRewriter &rewriter) const final {
298289
auto loc = op->getLoc();
299-
lowerOpToLoops(op, operands, rewriter,
300-
[loc](OpBuilder &builder, ValueRange memRefOperands,
301-
ValueRange loopIvs) {
302-
// Generate an adaptor for the remapped operands of the
303-
// TransposeOp. This allows for using the nice named
304-
// accessors that are generated by the ODS.
305-
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
306-
Value input = transposeAdaptor.getInput();
307-
308-
// Transpose the elements by generating a load from the
309-
// reverse indices.
310-
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
311-
return affine::AffineLoadOp::create(builder, loc, input,
312-
reverseIvs);
313-
});
290+
lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
291+
Value input = adaptor.getInput();
292+
293+
// Transpose the elements by generating a load from the
294+
// reverse indices.
295+
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
296+
return affine::AffineLoadOp::create(builder, loc, input, reverseIvs);
297+
});
314298
return success();
315299
}
316300
};

0 commit comments

Comments
 (0)