diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 0f801ebb6f589..ff1900bc8f2eb 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -31,6 +31,7 @@ class FuncOp;
 
 namespace memref {
 class AllocOp;
+class AllocaOp;
 } // namespace memref
 
 namespace affine {
@@ -245,7 +246,12 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
 /// Rewrites the memref defined by this alloc op to have an identity layout map
 /// and updates all its indexing uses. Returns failure if any of its uses
 /// escape (while leaving the IR in a valid state).
-LogicalResult normalizeMemRef(memref::AllocOp *op);
+template <typename AllocLikeOp>
+LogicalResult normalizeMemRef(AllocLikeOp *op);
+extern template LogicalResult
+normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
+extern template LogicalResult
+normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
 
 /// Normalizes `memrefType` so that the affine layout map of the memref is
 /// transformed to an identity map with a new shape being computed for the
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 9e3257a62b12f..7ef016f88be37 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1737,9 +1737,10 @@ static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
 /// %c4 = arith.constant 4 : index
 /// %1 = affine.apply #map1(%c4, %0)
 /// %2 = affine.apply #map2(%c4, %0)
+template <typename AllocLikeOp>
 static void createNewDynamicSizes(MemRefType oldMemRefType,
                                   MemRefType newMemRefType, AffineMap map,
-                                  memref::AllocOp *allocOp, OpBuilder b,
+                                  AllocLikeOp *allocOp, OpBuilder b,
                                   SmallVectorImpl<Value> &newDynamicSizes) {
   // Create new input for AffineApplyOp.
   SmallVector<Value, 4> inAffineApply;
@@ -1786,7 +1787,8 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
 }
 
 // TODO: Currently works for static memrefs with a single layout map.
-LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
+template <typename AllocLikeOp>
+LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
   MemRefType memrefType = allocOp->getType();
   OpBuilder b(*allocOp);
 
@@ -1802,7 +1804,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
 
   SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
-  memref::AllocOp newAlloc;
+  AllocLikeOp newAlloc;
   // Check if `layoutMap` is a tiled layout. Only single layout map is
   // supported for normalizing dynamic memrefs.
   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
@@ -1814,11 +1816,11 @@
                           newDynamicSizes);
     // Add the new dynamic sizes in new AllocOp.
     newAlloc =
-        b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
-                                  newDynamicSizes, allocOp->getAlignmentAttr());
+        b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
+                              allocOp->getAlignmentAttr());
   } else {
-    newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
-                                         allocOp->getAlignmentAttr());
+    newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
+                                     allocOp->getAlignmentAttr());
   }
   // Replace all uses of the old memref.
   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
@@ -1843,6 +1845,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
   return success();
 }
 
+template LogicalResult
+mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
+template LogicalResult
+mlir::affine::normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
+
 MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
   unsigned rank = memrefType.getRank();
   if (rank == 0)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 33772ccb7dd9d..08b853fe65b85 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -151,11 +151,11 @@
   });
 }
 
-/// Check whether all the uses of AllocOps, CallOps and function arguments of a
-/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
-/// or ReturnOp. Only if these constraints are satisfied will the function
-/// become a candidate for normalization. When the uses of a memref are
-/// non-normalizable and the memref map layout is trivial (identity), we can
+/// Check whether all the uses of AllocOps, AllocaOps, CallOps and function
+/// arguments of a function are either of dereferencing type or are uses in:
+/// DeallocOp, CallOp or ReturnOp. Only if these constraints are satisfied will
+/// the function become a candidate for normalization. When the uses of a memref
+/// are non-normalizable and the memref map layout is trivial (identity), we can
 /// still label the entire function as normalizable. We assume external
 /// functions to be normalizable.
 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
@@ -174,6 +174,17 @@
           .wasInterrupted())
     return false;
 
+  if (funcOp
+          .walk([&](memref::AllocaOp allocaOp) -> WalkResult {
+            Value oldMemRef = allocaOp.getResult();
+            if (!allocaOp.getType().getLayout().isIdentity() &&
+                !isMemRefNormalizable(oldMemRef.getUsers()))
+              return WalkResult::interrupt();
+            return WalkResult::advance();
+          })
+          .wasInterrupted())
+    return false;
+
   if (funcOp
           .walk([&](func::CallOp callOp) -> WalkResult {
             for (unsigned resIndex :
@@ -335,18 +346,23 @@
 }
 
 /// Normalizes the memrefs within a function which includes those arising as a
-/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
-/// is used to help update function's signature after normalization.
+/// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp
+/// argument is used to help update function's signature after normalization.
 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
                                               ModuleOp moduleOp) {
   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
-  // alloc ops first and then process since normalizeMemRef replaces/erases ops
-  // during memref rewriting.
+  // alloc/alloca ops first and then process since normalizeMemRef
+  // replaces/erases ops during memref rewriting.
   SmallVector<memref::AllocOp> allocOps;
   funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
   for (memref::AllocOp allocOp : allocOps)
     (void)normalizeMemRef(&allocOp);
 
+  SmallVector<memref::AllocaOp> allocaOps;
+  funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
+  for (memref::AllocaOp allocaOp : allocaOps)
+    (void)normalizeMemRef(&allocaOp);
+
   // We use this OpBuilder to create new memref layout later.
   OpBuilder b(funcOp);
 
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index 6d20ccbf2ca05..e93a1a4ebae53 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -31,6 +31,20 @@ func.func @permute() {
 // CHECK-NEXT:    memref.dealloc [[MEM]]
 // CHECK-NEXT:    return
 
+// CHECK-LABEL: func @alloca
+func.func @alloca(%idx : index) {
+  // CHECK-NEXT: memref.alloca() : memref<65xf32>
+  %A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
+  affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+  affine.for %i = 0 to 64 {
+    %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+    "prevent.dce"(%1) : (f32) -> ()
+    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
+  }
+  return
+}
+
 // CHECK-LABEL: func @shift
 func.func @shift(%idx : index) {
   // CHECK-NEXT: memref.alloc() : memref<65xf32>