4444using namespace mlir ;
4545
4646// ===----------------------------------------------------------------------===//
47- // ToyToAffine RewritePatterns
47+ // ToyToAffine Conversion Patterns
4848// ===----------------------------------------------------------------------===//
4949
5050// / Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6969}
7070
7171// / This defines the function type used to process an iteration of a lowered
72- // / loop. It takes as input an OpBuilder, an range of memRefOperands
73- // / corresponding to the operands of the input operation, and the range of loop
74- // / induction variables for the iteration. It returns a value to store at the
75- // / current index of the iteration.
76- using LoopIterationFn = function_ref<Value(
77- OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
78-
79- static void lowerOpToLoops (Operation *op, ValueRange operands,
80- PatternRewriter &rewriter,
72+ // / loop. It takes as input an OpBuilder and the range of loop induction
73+ // / variables for the iteration. It returns a value to store at the current
74+ // / index of the iteration.
75+ using LoopIterationFn =
76+ function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
77+
78+ static void lowerOpToLoops (Operation *op, PatternRewriter &rewriter,
8179 LoopIterationFn processIteration) {
8280 auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin ()));
8381 auto loc = op->getLoc ();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
9593 affine::buildAffineLoopNest (
9694 rewriter, loc, lowerBounds, tensorType.getShape (), steps,
9795 [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
98- // Call the processing function with the rewriter, the memref operands,
99- // and the loop induction variables. This function will return the value
100- // to store at the current index.
101- Value valueToStore = processIteration (nestedBuilder, operands, ivs);
96+ // Call the processing function with the rewriter and the loop
97+ // induction variables. This function will return the value to store at
98+ // the current index.
99+ Value valueToStore = processIteration (nestedBuilder, ivs);
102100 affine::AffineStoreOp::create (nestedBuilder, loc, valueToStore, alloc,
103101 ivs);
104102 });
@@ -109,53 +107,46 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
109107
110108namespace {
111109// ===----------------------------------------------------------------------===//
112- // ToyToAffine RewritePatterns : Binary operations
110+ // ToyToAffine Conversion Patterns : Binary operations
113111// ===----------------------------------------------------------------------===//
114112
115113template <typename BinaryOp, typename LoweredBinaryOp>
116- struct BinaryOpLowering : public ConversionPattern {
117- BinaryOpLowering (MLIRContext *ctx)
118- : ConversionPattern( BinaryOp::getOperationName(), 1 , ctx) {}
114+ struct BinaryOpLowering : public OpConversionPattern <BinaryOp> {
115+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
116+ using OpAdaptor = typename OpConversionPattern< BinaryOp>::OpAdaptor;
119117
120118 LogicalResult
121- matchAndRewrite (Operation * op, ArrayRef<Value> operands ,
119+ matchAndRewrite (BinaryOp op, OpAdaptor adaptor ,
122120 ConversionPatternRewriter &rewriter) const final {
123121 auto loc = op->getLoc ();
124- lowerOpToLoops (op, operands, rewriter,
125- [loc](OpBuilder &builder, ValueRange memRefOperands,
126- ValueRange loopIvs) {
127- // Generate an adaptor for the remapped operands of the
128- // BinaryOp. This allows for using the nice named accessors
129- // that are generated by the ODS.
130- typename BinaryOp::Adaptor binaryAdaptor (memRefOperands);
131-
132- // Generate loads for the element of 'lhs' and 'rhs' at the
133- // inner loop.
134- auto loadedLhs = affine::AffineLoadOp::create (
135- builder, loc, binaryAdaptor.getLhs (), loopIvs);
136- auto loadedRhs = affine::AffineLoadOp::create (
137- builder, loc, binaryAdaptor.getRhs (), loopIvs);
138-
139- // Create the binary operation performed on the loaded
140- // values.
141- return LoweredBinaryOp::create (builder, loc, loadedLhs,
142- loadedRhs);
143- });
122+ lowerOpToLoops (op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
123+ // Generate loads for the element of 'lhs' and 'rhs' at the
124+ // inner loop.
125+ auto loadedLhs =
126+ affine::AffineLoadOp::create (builder, loc, adaptor.getLhs (), loopIvs);
127+ auto loadedRhs =
128+ affine::AffineLoadOp::create (builder, loc, adaptor.getRhs (), loopIvs);
129+
130+ // Create the binary operation performed on the loaded
131+ // values.
132+ return LoweredBinaryOp::create (builder, loc, loadedLhs, loadedRhs);
133+ });
144134 return success ();
145135 }
146136};
147137using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
148138using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
149139
150140// ===----------------------------------------------------------------------===//
151- // ToyToAffine RewritePatterns : Constant operations
141+ // ToyToAffine Conversion Patterns : Constant operations
152142// ===----------------------------------------------------------------------===//
153143
154- struct ConstantOpLowering : public OpRewritePattern <toy::ConstantOp> {
155- using OpRewritePattern <toy::ConstantOp>::OpRewritePattern ;
144+ struct ConstantOpLowering : public OpConversionPattern <toy::ConstantOp> {
145+ using OpConversionPattern <toy::ConstantOp>::OpConversionPattern ;
156146
157- LogicalResult matchAndRewrite (toy::ConstantOp op,
158- PatternRewriter &rewriter) const final {
147+ LogicalResult
148+ matchAndRewrite (toy::ConstantOp op, OpAdaptor adaptor,
149+ ConversionPatternRewriter &rewriter) const final {
159150 DenseElementsAttr constantValue = op.getValue ();
160151 Location loc = op.getLoc ();
161152
@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
216207};
217208
218209// ===----------------------------------------------------------------------===//
219- // ToyToAffine RewritePatterns : Func operations
210+ // ToyToAffine Conversion Patterns : Func operations
220211// ===----------------------------------------------------------------------===//
221212
222213struct FuncOpLowering : public OpConversionPattern <toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
247238};
248239
249240// ===----------------------------------------------------------------------===//
250- // ToyToAffine RewritePatterns : Print operations
241+ // ToyToAffine Conversion Patterns : Print operations
251242// ===----------------------------------------------------------------------===//
252243
253244struct PrintOpLowering : public OpConversionPattern <toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
265256};
266257
267258// ===----------------------------------------------------------------------===//
268- // ToyToAffine RewritePatterns : Return operations
259+ // ToyToAffine Conversion Patterns : Return operations
269260// ===----------------------------------------------------------------------===//
270261
271- struct ReturnOpLowering : public OpRewritePattern <toy::ReturnOp> {
272- using OpRewritePattern <toy::ReturnOp>::OpRewritePattern ;
262+ struct ReturnOpLowering : public OpConversionPattern <toy::ReturnOp> {
263+ using OpConversionPattern <toy::ReturnOp>::OpConversionPattern ;
273264
274- LogicalResult matchAndRewrite (toy::ReturnOp op,
275- PatternRewriter &rewriter) const final {
265+ LogicalResult
266+ matchAndRewrite (toy::ReturnOp op, OpAdaptor adaptor,
267+ ConversionPatternRewriter &rewriter) const final {
276268 // During this lowering, we expect that all function calls have been
277269 // inlined.
278270 if (op.hasOperand ())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
285277};
286278
287279// ===----------------------------------------------------------------------===//
288- // ToyToAffine RewritePatterns : Transpose operations
280+ // ToyToAffine Conversion Patterns : Transpose operations
289281// ===----------------------------------------------------------------------===//
290282
291- struct TransposeOpLowering : public ConversionPattern {
292- TransposeOpLowering (MLIRContext *ctx)
293- : ConversionPattern(toy::TransposeOp::getOperationName(), 1 , ctx) {}
283+ struct TransposeOpLowering : public OpConversionPattern <toy::TransposeOp> {
284+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
294285
295286 LogicalResult
296- matchAndRewrite (Operation * op, ArrayRef<Value> operands ,
287+ matchAndRewrite (toy::TransposeOp op, OpAdaptor adaptor ,
297288 ConversionPatternRewriter &rewriter) const final {
298289 auto loc = op->getLoc ();
299- lowerOpToLoops (op, operands, rewriter,
300- [loc](OpBuilder &builder, ValueRange memRefOperands,
301- ValueRange loopIvs) {
302- // Generate an adaptor for the remapped operands of the
303- // TransposeOp. This allows for using the nice named
304- // accessors that are generated by the ODS.
305- toy::TransposeOpAdaptor transposeAdaptor (memRefOperands);
306- Value input = transposeAdaptor.getInput ();
307-
308- // Transpose the elements by generating a load from the
309- // reverse indices.
310- SmallVector<Value, 2 > reverseIvs (llvm::reverse (loopIvs));
311- return affine::AffineLoadOp::create (builder, loc, input,
312- reverseIvs);
313- });
290+ lowerOpToLoops (op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
291+ Value input = adaptor.getInput ();
292+
293+ // Transpose the elements by generating a load from the
294+ // reverse indices.
295+ SmallVector<Value, 2 > reverseIvs (llvm::reverse (loopIvs));
296+ return affine::AffineLoadOp::create (builder, loc, input, reverseIvs);
297+ });
314298 return success ();
315299 }
316300};
0 commit comments