@@ -99,7 +99,8 @@ namespace {
9999static Value getTargetMemref (Operation *op) {
100100 return llvm::TypeSwitch<Operation *, Value>(op)
101101 .template Case <memref::LoadOp, memref::StoreOp, memref::AllocaOp,
102- memref::AllocOp>([](auto op) { return op.getMemref (); })
102+ memref::AllocOp, memref::DeallocOp>(
103+ [](auto op) { return op.getMemref (); })
103104 .template Case <vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
104105 vector::MaskedStoreOp, vector::TransferReadOp,
105106 vector::TransferWriteOp>(
@@ -189,6 +190,10 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
189190 rewriter, loc, op.getVector (), flatMemref, ValueRange{offset});
190191 rewriter.replaceOp (op, newTransferWrite);
191192 })
193+ .template Case <memref::DeallocOp>([&](auto op) {
194+ auto newDealloc = memref::DeallocOp::create (rewriter, loc, flatMemref);
195+ rewriter.replaceOp (op, newDealloc);
196+ })
192197 .Default ([&](auto op) {
193198 op->emitOpError (" unimplemented: do not know how to replace op." );
194199 });
@@ -197,7 +202,8 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
197202template <typename T>
198203static ValueRange getIndices (T op) {
199204 if constexpr (std::is_same_v<T, memref::AllocaOp> ||
200- std::is_same_v<T, memref::AllocOp>) {
205+ std::is_same_v<T, memref::AllocOp> ||
206+ std::is_same_v<T, memref::DeallocOp>) {
201207 return ValueRange{};
202208 } else {
203209 return op.getIndices ();
@@ -286,7 +292,8 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
286292 patterns.insert <MemRefRewritePattern<memref::LoadOp>,
287293 MemRefRewritePattern<memref::StoreOp>,
288294 MemRefRewritePattern<memref::AllocOp>,
289- MemRefRewritePattern<memref::AllocaOp>>(
295+ MemRefRewritePattern<memref::AllocaOp>,
296+ MemRefRewritePattern<memref::DeallocOp>>(
290297 patterns.getContext ());
291298}
292299
0 commit comments