Skip to content

Commit f0ea2d3

Browse files
committed
Adding memref alloc/alloca.
1 parent 3007c1d commit f0ea2d3

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@
2828
#include "mlir/IR/PatternMatch.h"
2929
#include "mlir/Pass/Pass.h"
3030
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31+
#include "llvm/ADT/SmallVector.h"
3132
#include "llvm/ADT/TypeSwitch.h"
3233

34+
#include <numeric>
35+
3336
namespace mlir {
3437
namespace memref {
3538
#define GEN_PASS_DEF_FLATTENMEMREFSPASS
@@ -182,6 +185,29 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
182185
}
183186
};
184187

188+
// For any memref op that emits a new memref.
189+
template <typename T>
190+
struct MemRefSourceRewritePattern : public OpRewritePattern<T> {
191+
using OpRewritePattern<T>::OpRewritePattern;
192+
LogicalResult matchAndRewrite(T op,
193+
PatternRewriter &rewriter) const override {
194+
if (!needFlattening(op) || !checkLayout(op))
195+
return failure();
196+
MemRefType sourceType = cast<MemRefType>(op.getType());
197+
198+
// Get flattened size, no strides.
199+
auto dimSizes = llvm::to_vector(sourceType.getShape());
200+
auto flattenedSize = std::accumulate(
201+
dimSizes.begin(), dimSizes.end(), 1, std::multiplies<int64_t>());
202+
auto flatMemrefType = MemRefType::get(
203+
/*shape=*/{flattenedSize}, sourceType.getElementType(),
204+
/*layout=*/nullptr, sourceType.getMemorySpace());
205+
rewriter.replaceOpWithNewOp<T>(
206+
op, flatMemrefType);
207+
return success();
208+
}
209+
};
210+
185211
struct FlattenMemrefsPass
186212
: public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
187213
using Base::Base;
@@ -206,6 +232,8 @@ struct FlattenMemrefsPass
206232
void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
207233
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
208234
MemRefRewritePattern<memref::StoreOp>,
235+
MemRefSourceRewritePattern<memref::AllocOp>,
236+
MemRefSourceRewritePattern<memref::AllocaOp>,
209237
MemRefRewritePattern<vector::LoadOp>,
210238
MemRefRewritePattern<vector::StoreOp>,
211239
MemRefRewritePattern<vector::TransferReadOp>,

0 commit comments

Comments
 (0)