Skip to content

Commit 2afb3cd

Browse files
[FlattenMemRef] Flatten MemRef AllocaOp (#8352)
1 parent 06d0c5d commit 2afb3cd

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

lib/Transforms/FlattenMemRefs.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,21 @@ struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
178178
}
179179
};
180180

181+
struct AllocaOpConversion : public OpConversionPattern<memref::AllocaOp> {
182+
using OpConversionPattern::OpConversionPattern;
183+
184+
LogicalResult
185+
matchAndRewrite(memref::AllocaOp op, OpAdaptor /*adaptor*/,
186+
ConversionPatternRewriter &rewriter) const override {
187+
MemRefType type = op.getType();
188+
if (isUniDimensional(type) || !type.hasStaticShape())
189+
return failure();
190+
MemRefType newType = getFlattenedMemRefType(type);
191+
rewriter.replaceOpWithNewOp<memref::AllocaOp>(op, newType);
192+
return success();
193+
}
194+
};
195+
181196
struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
182197
using OpConversionPattern::OpConversionPattern;
183198

@@ -352,6 +367,8 @@ static void populateFlattenMemRefsLegality(ConversionTarget &target) {
352367
target.addLegalDialect<arith::ArithDialect>();
353368
target.addDynamicallyLegalOp<memref::AllocOp>(
354369
[](memref::AllocOp op) { return isUniDimensional(op.getType()); });
370+
target.addDynamicallyLegalOp<memref::AllocaOp>(
371+
[](memref::AllocaOp op) { return isUniDimensional(op.getType()); });
355372
target.addDynamicallyLegalOp<memref::StoreOp>(
356373
[](memref::StoreOp op) { return op.getIndices().size() == 1; });
357374
target.addDynamicallyLegalOp<memref::LoadOp>(
@@ -426,8 +443,8 @@ struct FlattenMemRefPass
426443
RewritePatternSet patterns(ctx);
427444
SetVector<StringRef> rewrittenCallees;
428445
patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
429-
GlobalOpConversion, GetGlobalOpConversion, ReshapeOpConversion,
430-
OperandConversionPattern<func::ReturnOp>,
446+
AllocaOpConversion, GlobalOpConversion, GetGlobalOpConversion,
447+
ReshapeOpConversion, OperandConversionPattern<func::ReturnOp>,
431448
OperandConversionPattern<memref::DeallocOp>,
432449
CondBranchOpConversion,
433450
OperandConversionPattern<memref::DeallocOp>,

test/Transforms/flatten_memref.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,3 +314,14 @@ module {
314314
return
315315
}
316316
}
317+
318+
// -----
319+
320+
// CHECK-LABEL: func @allocas() -> memref<16xi32> {
321+
// CHECK: %[[VAL_0:.*]] = memref.alloca() : memref<16xi32>
322+
// CHECK: return %[[VAL_0]] : memref<16xi32>
323+
// CHECK: }
324+
func.func @allocas() -> memref<4x4xi32> {
325+
%0 = memref.alloca() : memref<4x4xi32>
326+
return %0 : memref<4x4xi32>
327+
}

0 commit comments

Comments
 (0)