44
44
using namespace mlir ;
45
45
46
46
// ===----------------------------------------------------------------------===//
47
- // ToyToAffine RewritePatterns
47
+ // ToyToAffine Conversion Patterns
48
48
// ===----------------------------------------------------------------------===//
49
49
50
50
// / Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
69
69
}
70
70
71
71
// / This defines the function type used to process an iteration of a lowered
72
- // / loop. It takes as input an OpBuilder, an range of memRefOperands
73
- // / corresponding to the operands of the input operation, and the range of loop
74
- // / induction variables for the iteration. It returns a value to store at the
75
- // / current index of the iteration.
76
- using LoopIterationFn = function_ref<Value(
77
- OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
78
-
79
- static void lowerOpToLoops (Operation *op, ValueRange operands,
80
- PatternRewriter &rewriter,
72
+ // / loop. It takes as input an OpBuilder and the range of loop induction
73
+ // / variables for the iteration. It returns a value to store at the current
74
+ // / index of the iteration.
75
+ using LoopIterationFn =
76
+ function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
77
+
78
+ static void lowerOpToLoops (Operation *op, PatternRewriter &rewriter,
81
79
LoopIterationFn processIteration) {
82
80
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin ()));
83
81
auto loc = op->getLoc ();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
95
93
affine::buildAffineLoopNest (
96
94
rewriter, loc, lowerBounds, tensorType.getShape (), steps,
97
95
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
98
- // Call the processing function with the rewriter, the memref operands,
99
- // and the loop induction variables. This function will return the value
100
- // to store at the current index.
101
- Value valueToStore = processIteration (nestedBuilder, operands, ivs);
96
+ // Call the processing function with the rewriter and the loop
97
+ // induction variables. This function will return the value to store at
98
+ // the current index.
99
+ Value valueToStore = processIteration (nestedBuilder, ivs);
102
100
affine::AffineStoreOp::create (nestedBuilder, loc, valueToStore, alloc,
103
101
ivs);
104
102
});
@@ -109,53 +107,46 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
109
107
110
108
namespace {
111
109
// ===----------------------------------------------------------------------===//
112
- // ToyToAffine RewritePatterns : Binary operations
110
+ // ToyToAffine Conversion Patterns : Binary operations
113
111
// ===----------------------------------------------------------------------===//
114
112
115
113
template <typename BinaryOp, typename LoweredBinaryOp>
116
- struct BinaryOpLowering : public ConversionPattern {
117
- BinaryOpLowering (MLIRContext *ctx)
118
- : ConversionPattern( BinaryOp::getOperationName(), 1 , ctx) {}
114
+ struct BinaryOpLowering : public OpConversionPattern <BinaryOp> {
115
+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
116
+ using OpAdaptor = typename OpConversionPattern< BinaryOp>::OpAdaptor;
119
117
120
118
LogicalResult
121
- matchAndRewrite (Operation * op, ArrayRef<Value> operands ,
119
+ matchAndRewrite (BinaryOp op, OpAdaptor adaptor ,
122
120
ConversionPatternRewriter &rewriter) const final {
123
121
auto loc = op->getLoc ();
124
- lowerOpToLoops (op, operands, rewriter,
125
- [loc](OpBuilder &builder, ValueRange memRefOperands,
126
- ValueRange loopIvs) {
127
- // Generate an adaptor for the remapped operands of the
128
- // BinaryOp. This allows for using the nice named accessors
129
- // that are generated by the ODS.
130
- typename BinaryOp::Adaptor binaryAdaptor (memRefOperands);
131
-
132
- // Generate loads for the element of 'lhs' and 'rhs' at the
133
- // inner loop.
134
- auto loadedLhs = affine::AffineLoadOp::create (
135
- builder, loc, binaryAdaptor.getLhs (), loopIvs);
136
- auto loadedRhs = affine::AffineLoadOp::create (
137
- builder, loc, binaryAdaptor.getRhs (), loopIvs);
138
-
139
- // Create the binary operation performed on the loaded
140
- // values.
141
- return LoweredBinaryOp::create (builder, loc, loadedLhs,
142
- loadedRhs);
143
- });
122
+ lowerOpToLoops (op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
123
+ // Generate loads for the element of 'lhs' and 'rhs' at the
124
+ // inner loop.
125
+ auto loadedLhs =
126
+ affine::AffineLoadOp::create (builder, loc, adaptor.getLhs (), loopIvs);
127
+ auto loadedRhs =
128
+ affine::AffineLoadOp::create (builder, loc, adaptor.getRhs (), loopIvs);
129
+
130
+ // Create the binary operation performed on the loaded
131
+ // values.
132
+ return LoweredBinaryOp::create (builder, loc, loadedLhs, loadedRhs);
133
+ });
144
134
return success ();
145
135
}
146
136
};
147
137
using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
148
138
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
149
139
150
140
// ===----------------------------------------------------------------------===//
151
- // ToyToAffine RewritePatterns : Constant operations
141
+ // ToyToAffine Conversion Patterns : Constant operations
152
142
// ===----------------------------------------------------------------------===//
153
143
154
- struct ConstantOpLowering : public OpRewritePattern <toy::ConstantOp> {
155
- using OpRewritePattern <toy::ConstantOp>::OpRewritePattern ;
144
+ struct ConstantOpLowering : public OpConversionPattern <toy::ConstantOp> {
145
+ using OpConversionPattern <toy::ConstantOp>::OpConversionPattern ;
156
146
157
- LogicalResult matchAndRewrite (toy::ConstantOp op,
158
- PatternRewriter &rewriter) const final {
147
+ LogicalResult
148
+ matchAndRewrite (toy::ConstantOp op, OpAdaptor adaptor,
149
+ ConversionPatternRewriter &rewriter) const final {
159
150
DenseElementsAttr constantValue = op.getValue ();
160
151
Location loc = op.getLoc ();
161
152
@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
216
207
};
217
208
218
209
// ===----------------------------------------------------------------------===//
219
- // ToyToAffine RewritePatterns : Func operations
210
+ // ToyToAffine Conversion Patterns : Func operations
220
211
// ===----------------------------------------------------------------------===//
221
212
222
213
struct FuncOpLowering : public OpConversionPattern <toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
247
238
};
248
239
249
240
// ===----------------------------------------------------------------------===//
250
- // ToyToAffine RewritePatterns : Print operations
241
+ // ToyToAffine Conversion Patterns : Print operations
251
242
// ===----------------------------------------------------------------------===//
252
243
253
244
struct PrintOpLowering : public OpConversionPattern <toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
265
256
};
266
257
267
258
// ===----------------------------------------------------------------------===//
268
- // ToyToAffine RewritePatterns : Return operations
259
+ // ToyToAffine Conversion Patterns : Return operations
269
260
// ===----------------------------------------------------------------------===//
270
261
271
- struct ReturnOpLowering : public OpRewritePattern <toy::ReturnOp> {
272
- using OpRewritePattern <toy::ReturnOp>::OpRewritePattern ;
262
+ struct ReturnOpLowering : public OpConversionPattern <toy::ReturnOp> {
263
+ using OpConversionPattern <toy::ReturnOp>::OpConversionPattern ;
273
264
274
- LogicalResult matchAndRewrite (toy::ReturnOp op,
275
- PatternRewriter &rewriter) const final {
265
+ LogicalResult
266
+ matchAndRewrite (toy::ReturnOp op, OpAdaptor adaptor,
267
+ ConversionPatternRewriter &rewriter) const final {
276
268
// During this lowering, we expect that all function calls have been
277
269
// inlined.
278
270
if (op.hasOperand ())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
285
277
};
286
278
287
279
// ===----------------------------------------------------------------------===//
288
- // ToyToAffine RewritePatterns : Transpose operations
280
+ // ToyToAffine Conversion Patterns : Transpose operations
289
281
// ===----------------------------------------------------------------------===//
290
282
291
- struct TransposeOpLowering : public ConversionPattern {
292
- TransposeOpLowering (MLIRContext *ctx)
293
- : ConversionPattern(toy::TransposeOp::getOperationName(), 1 , ctx) {}
283
+ struct TransposeOpLowering : public OpConversionPattern <toy::TransposeOp> {
284
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
294
285
295
286
LogicalResult
296
- matchAndRewrite (Operation * op, ArrayRef<Value> operands ,
287
+ matchAndRewrite (toy::TransposeOp op, OpAdaptor adaptor ,
297
288
ConversionPatternRewriter &rewriter) const final {
298
289
auto loc = op->getLoc ();
299
- lowerOpToLoops (op, operands, rewriter,
300
- [loc](OpBuilder &builder, ValueRange memRefOperands,
301
- ValueRange loopIvs) {
302
- // Generate an adaptor for the remapped operands of the
303
- // TransposeOp. This allows for using the nice named
304
- // accessors that are generated by the ODS.
305
- toy::TransposeOpAdaptor transposeAdaptor (memRefOperands);
306
- Value input = transposeAdaptor.getInput ();
307
-
308
- // Transpose the elements by generating a load from the
309
- // reverse indices.
310
- SmallVector<Value, 2 > reverseIvs (llvm::reverse (loopIvs));
311
- return affine::AffineLoadOp::create (builder, loc, input,
312
- reverseIvs);
313
- });
290
+ lowerOpToLoops (op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
291
+ Value input = adaptor.getInput ();
292
+
293
+ // Transpose the elements by generating a load from the
294
+ // reverse indices.
295
+ SmallVector<Value, 2 > reverseIvs (llvm::reverse (loopIvs));
296
+ return affine::AffineLoadOp::create (builder, loc, input, reverseIvs);
297
+ });
314
298
return success ();
315
299
}
316
300
};
0 commit comments