2828#include " mlir/IR/PatternMatch.h"
2929#include " mlir/Pass/Pass.h"
3030#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
31+ #include " llvm/ADT/SmallVector.h"
3132#include " llvm/ADT/TypeSwitch.h"
3233
34+ #include < numeric>
35+
3336namespace mlir {
3437namespace memref {
3538#define GEN_PASS_DEF_FLATTENMEMREFSPASS
@@ -182,6 +185,29 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
182185 }
183186};
184187
188+ // For any memref op that emits a new memref.
189+ template <typename T>
190+ struct MemRefSourceRewritePattern : public OpRewritePattern <T> {
191+ using OpRewritePattern<T>::OpRewritePattern;
192+ LogicalResult matchAndRewrite (T op,
193+ PatternRewriter &rewriter) const override {
194+ if (!needFlattening (op) || !checkLayout (op))
195+ return failure ();
196+ MemRefType sourceType = cast<MemRefType>(op.getType ());
197+
198+ // Get flattened size, no strides.
199+ auto dimSizes = llvm::to_vector (sourceType.getShape ());
200+ auto flattenedSize = std::accumulate (
201+ dimSizes.begin (), dimSizes.end (), 1 , std::multiplies<int64_t >());
202+ auto flatMemrefType = MemRefType::get (
203+ /* shape=*/ {flattenedSize}, sourceType.getElementType (),
204+ /* layout=*/ nullptr , sourceType.getMemorySpace ());
205+ rewriter.replaceOpWithNewOp <T>(
206+ op, flatMemrefType);
207+ return success ();
208+ }
209+ };
210+
185211struct FlattenMemrefsPass
186212 : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
187213 using Base::Base;
@@ -206,6 +232,8 @@ struct FlattenMemrefsPass
206232void memref::populateFlattenMemrefsPatterns (RewritePatternSet &patterns) {
207233 patterns.insert <MemRefRewritePattern<memref::LoadOp>,
208234 MemRefRewritePattern<memref::StoreOp>,
235+ MemRefSourceRewritePattern<memref::AllocOp>,
236+ MemRefSourceRewritePattern<memref::AllocaOp>,
209237 MemRefRewritePattern<vector::LoadOp>,
210238 MemRefRewritePattern<vector::StoreOp>,
211239 MemRefRewritePattern<vector::TransferReadOp>,
0 commit comments