Skip to content

Commit bc742d1

Browse files
committed
Update
1 parent 7c861bc commit bc742d1

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ namespace {
9999
static Value getTargetMemref(Operation *op) {
100100
return llvm::TypeSwitch<Operation *, Value>(op)
101101
.template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
102-
memref::AllocOp>([](auto op) { return op.getMemref(); })
102+
memref::AllocOp, memref::DeallocOp>(
103+
[](auto op) { return op.getMemref(); })
103104
.template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
104105
vector::MaskedStoreOp, vector::TransferReadOp,
105106
vector::TransferWriteOp>(
@@ -189,6 +190,10 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
189190
rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
190191
rewriter.replaceOp(op, newTransferWrite);
191192
})
193+
.template Case<memref::DeallocOp>([&](auto op) {
194+
auto newDealloc = memref::DeallocOp::create(rewriter, loc, flatMemref);
195+
rewriter.replaceOp(op, newDealloc);
196+
})
192197
.Default([&](auto op) {
193198
op->emitOpError("unimplemented: do not know how to replace op.");
194199
});
@@ -197,7 +202,8 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
197202
template <typename T>
198203
static ValueRange getIndices(T op) {
199204
if constexpr (std::is_same_v<T, memref::AllocaOp> ||
200-
std::is_same_v<T, memref::AllocOp>) {
205+
std::is_same_v<T, memref::AllocOp> ||
206+
std::is_same_v<T, memref::DeallocOp>) {
201207
return ValueRange{};
202208
} else {
203209
return op.getIndices();
@@ -286,7 +292,8 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
286292
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
287293
MemRefRewritePattern<memref::StoreOp>,
288294
MemRefRewritePattern<memref::AllocOp>,
289-
MemRefRewritePattern<memref::AllocaOp>>(
295+
MemRefRewritePattern<memref::AllocaOp>,
296+
MemRefRewritePattern<memref::DeallocOp>>(
290297
patterns.getContext());
291298
}
292299

mlir/test/Dialect/MemRef/flatten_memref.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,41 @@ func.func @load_scalar_from_memref_static_dim_col_major(%input: memref<4x8xf32,
298298
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
299299
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[1, 4], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
300300
// CHECK: memref.load %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[1], offset: 100>>
301+
302+
// -----
303+
304+
func.func @dealloc_static_memref(%input: memref<4x8xf32>) {
305+
memref.dealloc %input : memref<4x8xf32>
306+
return
307+
}
308+
309+
// CHECK-LABEL: func @dealloc_static_memref
310+
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>)
311+
// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [32], strides: [1] : memref<4x8xf32> to memref<32xf32, strided<[1]>>
312+
// CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1]>>
313+
314+
// -----
315+
316+
func.func @dealloc_dynamic_memref(%input: memref<?x?xf32>) {
317+
memref.dealloc %input : memref<?x?xf32>
318+
return
319+
}
320+
321+
// CHECK-LABEL: func @dealloc_dynamic_memref
322+
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>)
323+
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
324+
// CHECK: %[[SIZE:.*]] = affine.max #{{.*}}()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
325+
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [%[[SIZE]]], strides: [1] : memref<?x?xf32> to memref<?xf32, strided<[1]>>
326+
// CHECK: memref.dealloc %[[REINT]] : memref<?xf32, strided<[1]>>
327+
328+
// -----
329+
330+
func.func @dealloc_strided_memref(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) {
331+
memref.dealloc %input : memref<4x8xf32, strided<[8, 1], offset: 100>>
332+
return
333+
}
334+
335+
// CHECK-LABEL: func @dealloc_strided_memref
336+
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 1], offset: 100>>)
337+
// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
338+
// CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1], offset: 100>>

0 commit comments

Comments
 (0)