@@ -178,6 +178,21 @@ struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
178178 }
179179};
180180
181+ struct AllocaOpConversion : public OpConversionPattern <memref::AllocaOp> {
182+ using OpConversionPattern::OpConversionPattern;
183+
184+ LogicalResult
185+ matchAndRewrite (memref::AllocaOp op, OpAdaptor /* adaptor*/ ,
186+ ConversionPatternRewriter &rewriter) const override {
187+ MemRefType type = op.getType ();
188+ if (isUniDimensional (type) || !type.hasStaticShape ())
189+ return failure ();
190+ MemRefType newType = getFlattenedMemRefType (type);
191+ rewriter.replaceOpWithNewOp <memref::AllocaOp>(op, newType);
192+ return success ();
193+ }
194+ };
195+
181196struct GlobalOpConversion : public OpConversionPattern <memref::GlobalOp> {
182197 using OpConversionPattern::OpConversionPattern;
183198
@@ -352,6 +367,8 @@ static void populateFlattenMemRefsLegality(ConversionTarget &target) {
352367 target.addLegalDialect <arith::ArithDialect>();
353368 target.addDynamicallyLegalOp <memref::AllocOp>(
354369 [](memref::AllocOp op) { return isUniDimensional (op.getType ()); });
370+ target.addDynamicallyLegalOp <memref::AllocaOp>(
371+ [](memref::AllocaOp op) { return isUniDimensional (op.getType ()); });
355372 target.addDynamicallyLegalOp <memref::StoreOp>(
356373 [](memref::StoreOp op) { return op.getIndices ().size () == 1 ; });
357374 target.addDynamicallyLegalOp <memref::LoadOp>(
@@ -426,8 +443,8 @@ struct FlattenMemRefPass
426443 RewritePatternSet patterns (ctx);
427444 SetVector<StringRef> rewrittenCallees;
428445 patterns.add <LoadOpConversion, StoreOpConversion, AllocOpConversion,
429- GlobalOpConversion, GetGlobalOpConversion, ReshapeOpConversion ,
430- OperandConversionPattern<func::ReturnOp>,
446+ AllocaOpConversion, GlobalOpConversion, GetGlobalOpConversion ,
447+ ReshapeOpConversion, OperandConversionPattern<func::ReturnOp>,
431448 OperandConversionPattern<memref::DeallocOp>,
432449 CondBranchOpConversion,
433450 OperandConversionPattern<memref::DeallocOp>,
0 commit comments