Skip to content

Commit d30ce39

Browse files
committed
support alloc/alloca
1 parent f0ea2d3 commit d30ce39

File tree

2 files changed

+102
-39
lines changed

2 files changed

+102
-39
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 81 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ static bool checkLayout(Value val) {
103103
namespace {
104104
static Value getTargetMemref(Operation *op) {
105105
return llvm::TypeSwitch<Operation *, Value>(op)
106-
.template Case<memref::LoadOp, memref::StoreOp>(
107-
[](auto op) { return op.getMemref(); })
106+
.template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
107+
memref::AllocOp>([](auto op) { return op.getMemref(); })
108108
.template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
109109
vector::MaskedStoreOp>(
110110
[](auto op) { return op.getBase(); })
@@ -113,54 +113,78 @@ static Value getTargetMemref(Operation *op) {
113113
.Default([](auto) { return Value{}; });
114114
}
115115

116-
static void replaceOp(Operation *op, PatternRewriter &rewriter,
117-
Value flatMemref, Value offset) {
116+
template <typename T>
117+
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
118+
Value offset) {
118119
auto loc = op->getLoc();
119-
llvm::TypeSwitch<Operation *>(op)
120-
.Case<memref::LoadOp>([&](auto op) {
120+
llvm::TypeSwitch<Operation *>(op.getOperation())
121+
.template Case<memref::AllocOp>([&](auto oper) {
122+
// grab flatMemref's type, and replace op with a new one. Then
123+
// reinterpret it back.
124+
auto flatMemrefType = cast<MemRefType>(flatMemref.getType());
125+
auto loc = oper.getLoc();
126+
auto newAlloc = rewriter.create<memref::AllocOp>(
127+
loc, flatMemrefType, oper.getAlignmentAttr());
128+
auto originalType = cast<MemRefType>(oper.getType());
129+
130+
auto rank = originalType.getRank();
131+
SmallVector<OpFoldResult, 4> sizes, strides;
132+
sizes.resize(rank);
133+
strides.resize(rank);
134+
int64_t staticStride = 1;
135+
for (int i = rank - 1; i >= 0; --i) {
136+
sizes[i] = rewriter.getIndexAttr(originalType.getShape()[i]);
137+
strides[i] = rewriter.getIndexAttr(staticStride);
138+
staticStride *= originalType.getShape()[i];
139+
}
140+
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
141+
op, originalType, newAlloc,
142+
/*offset=*/rewriter.getIndexAttr(0), sizes, strides);
143+
})
144+
.template Case<memref::LoadOp>([&](auto op) {
121145
auto newLoad = rewriter.create<memref::LoadOp>(
122146
loc, op->getResultTypes(), flatMemref, ValueRange{offset});
123147
newLoad->setAttrs(op->getAttrs());
124148
rewriter.replaceOp(op, newLoad.getResult());
125149
})
126-
.Case<memref::StoreOp>([&](auto op) {
150+
.template Case<memref::StoreOp>([&](auto op) {
127151
auto newStore = rewriter.create<memref::StoreOp>(
128152
loc, op->getOperands().front(), flatMemref, ValueRange{offset});
129153
newStore->setAttrs(op->getAttrs());
130154
rewriter.replaceOp(op, newStore);
131155
})
132-
.Case<vector::LoadOp>([&](auto op) {
156+
.template Case<vector::LoadOp>([&](auto op) {
133157
auto newLoad = rewriter.create<vector::LoadOp>(
134158
loc, op->getResultTypes(), flatMemref, ValueRange{offset});
135159
newLoad->setAttrs(op->getAttrs());
136160
rewriter.replaceOp(op, newLoad.getResult());
137161
})
138-
.Case<vector::StoreOp>([&](auto op) {
162+
.template Case<vector::StoreOp>([&](auto op) {
139163
auto newStore = rewriter.create<vector::StoreOp>(
140164
loc, op->getOperands().front(), flatMemref, ValueRange{offset});
141165
newStore->setAttrs(op->getAttrs());
142166
rewriter.replaceOp(op, newStore);
143167
})
144-
.Case<vector::MaskedLoadOp>([&](auto op) {
168+
.template Case<vector::MaskedLoadOp>([&](auto op) {
145169
auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
146170
loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
147171
op.getPassThru());
148172
newMaskedLoad->setAttrs(op->getAttrs());
149173
rewriter.replaceOp(op, newMaskedLoad.getResult());
150174
})
151-
.Case<vector::MaskedStoreOp>([&](auto op) {
175+
.template Case<vector::MaskedStoreOp>([&](auto op) {
152176
auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
153177
loc, flatMemref, ValueRange{offset}, op.getMask(),
154178
op.getValueToStore());
155179
newMaskedStore->setAttrs(op->getAttrs());
156180
rewriter.replaceOp(op, newMaskedStore);
157181
})
158-
.Case<vector::TransferReadOp>([&](auto op) {
182+
.template Case<vector::TransferReadOp>([&](auto op) {
159183
auto newTransferRead = rewriter.create<vector::TransferReadOp>(
160184
loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
161185
rewriter.replaceOp(op, newTransferRead.getResult());
162186
})
163-
.Case<vector::TransferWriteOp>([&](auto op) {
187+
.template Case<vector::TransferWriteOp>([&](auto op) {
164188
auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
165189
loc, op.getVector(), flatMemref, ValueRange{offset});
166190
rewriter.replaceOp(op, newTransferWrite);
@@ -170,6 +194,16 @@ static void replaceOp(Operation *op, PatternRewriter &rewriter,
170194
});
171195
}
172196

197+
template <typename T>
198+
static ValueRange getIndices(T op) {
199+
if constexpr (std::is_same_v<T, memref::AllocaOp> ||
200+
std::is_same_v<T, memref::AllocOp>) {
201+
return ValueRange{};
202+
} else {
203+
return op.getIndices();
204+
}
205+
}
206+
173207
template <typename T>
174208
struct MemRefRewritePattern : public OpRewritePattern<T> {
175209
using OpRewritePattern<T>::OpRewritePattern;
@@ -179,34 +213,42 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
179213
if (!needFlattening(memref) || !checkLayout(memref))
180214
return failure();
181215
auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
182-
rewriter, op->getLoc(), memref, op.getIndices());
183-
replaceOp(op, rewriter, flatMemref, offset);
216+
rewriter, op->getLoc(), memref, getIndices<T>(op));
217+
replaceOp<T>(op, rewriter, flatMemref, offset);
184218
return success();
185219
}
186220
};
187221

188-
// For any memref op that emits a new memref.
189-
template <typename T>
190-
struct MemRefSourceRewritePattern : public OpRewritePattern<T> {
191-
using OpRewritePattern<T>::OpRewritePattern;
192-
LogicalResult matchAndRewrite(T op,
193-
PatternRewriter &rewriter) const override {
194-
if (!needFlattening(op) || !checkLayout(op))
195-
return failure();
196-
MemRefType sourceType = cast<MemRefType>(op.getType());
197-
198-
// Get flattened size, no strides.
199-
auto dimSizes = llvm::to_vector(sourceType.getShape());
200-
auto flattenedSize = std::accumulate(
201-
dimSizes.begin(), dimSizes.end(), 1, std::multiplies<int64_t>());
202-
auto flatMemrefType = MemRefType::get(
203-
/*shape=*/{flattenedSize}, sourceType.getElementType(),
204-
/*layout=*/nullptr, sourceType.getMemorySpace());
205-
rewriter.replaceOpWithNewOp<T>(
206-
op, flatMemrefType);
207-
return success();
208-
}
209-
};
222+
// TODO(review): delete this commented-out MemRefSourceRewritePattern below —
// it has been superseded by MemRefRewritePattern<memref::AllocOp/AllocaOp>,
// and dead code should not be kept in comments (version control preserves it).
// // For any memref op that emits a new memref.
223+
// template <typename T>
224+
// struct MemRefSourceRewritePattern : public OpRewritePattern<T> {
225+
// using OpRewritePattern<T>::OpRewritePattern;
226+
// LogicalResult matchAndRewrite(T op,
227+
// PatternRewriter &rewriter) const override {
228+
// if (!needFlattening(op) || !checkLayout(op))
229+
// return failure();
230+
// MemRefType sourceType = cast<MemRefType>(op.getType());
231+
232+
// auto mixedSizes = op.getMixedSizes();
233+
234+
// // Get flattened size, no strides.
235+
// auto flattenedSize = std::accumulate(
236+
// mixedSizes.begin(), mixedSizes.end(), 1,
237+
// [](int64_t a, OpFoldResult b) {
238+
// return a * getConstantIntValue(b).value_or(1);
239+
// });
240+
241+
// auto flatMemrefType = MemRefType::get(
242+
// /*shape=*/{flattenedSize}, sourceType.getElementType(),
243+
// /*layout=*/nullptr, sourceType.getMemorySpace());
244+
// auto newSource = rewriter.create<T>(
245+
// op.getLoc(), flatMemrefType, op.getDynamicSizes());
246+
// auto reinterpretCast = rewriter.create<memref::ReinterpretCastOp>(
247+
// op.getLoc(), sourceType, newSource, op.getOffset(),
248+
// op.getMixedSizes(), op.getStrides());
249+
// return success();
250+
// }
251+
// };
210252

211253
struct FlattenMemrefsPass
212254
: public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
@@ -232,8 +274,8 @@ struct FlattenMemrefsPass
232274
void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
233275
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
234276
MemRefRewritePattern<memref::StoreOp>,
235-
MemRefSourceRewritePattern<memref::AllocOp>,
236-
MemRefSourceRewritePattern<memref::AllocaOp>,
277+
MemRefRewritePattern<memref::AllocOp>,
278+
MemRefRewritePattern<memref::AllocaOp>,
237279
MemRefRewritePattern<vector::LoadOp>,
238280
MemRefRewritePattern<vector::StoreOp>,
239281
MemRefRewritePattern<vector::TransferReadOp>,

mlir/test/Dialect/MemRef/flatten_memref.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,24 @@ func.func @transfer_write_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %
217217
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
218218
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
219219
// CHECK: vector.transfer_write %[[ARG1]], %[[REINT]][%[[IDX]]]
220+
221+
// -----
222+
223+
// Verify that a static, identity-layout alloc is flattened to a rank-1 alloc
// and reinterpret_cast back to the original 4x8 shape for its users.
// NOTE(review): this test previously had no CHECK directives, so FileCheck
// verified nothing. Confirm the exact IR below against the pass output.
func.func @alloc_4x8_f32() -> memref<4x8xf32> {
  %0 = memref.alloc() : memref<4x8xf32>
  return %0 : memref<4x8xf32>
}
// CHECK-LABEL: func.func @alloc_4x8_f32
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32>
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32> to memref<4x8xf32>
// CHECK: return %[[CAST]] : memref<4x8xf32>
231+
232+
// -----
233+
234+
// Verify that a flattened alloc composes with a flattened vector.load:
// the 4x8xf32 alloc collapses to 32xf32, and the load's [3, 6] indices are
// linearized into a single offset (3 * 8 + 6 = 30).
// NOTE(review): this test previously had no CHECK directives, so FileCheck
// verified nothing. Confirm the exact IR below against the pass output.
func.func @chained_alloc_load() -> vector<8xf32> {
  %c3 = arith.constant 3 : index
  %c6 = arith.constant 6 : index
  %0 = memref.alloc() : memref<4x8xf32>
  %value = vector.load %0[%c3, %c6] : memref<4x8xf32>, vector<8xf32>
  return %value : vector<8xf32>
}
// CHECK-LABEL: func.func @chained_alloc_load
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32>
// CHECK: vector.load %{{.*}}[%{{.*}}] : memref<32xf32>, vector<8xf32>

0 commit comments

Comments
 (0)