Skip to content

Conversation

@mgehre-amd
Copy link
Contributor

The pass was only handling memref.alloc and this extends it to also handle memref.alloca.

The pass was only handling memref.alloc,
and this extends it to also handle memref.alloca.
@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir-memref

Author: Matthias Gehre (mgehre-amd)

Changes

The pass was only handling memref.alloc and this extends it to also handle memref.alloca.


Full diff: https://github.com/llvm/llvm-project/pull/123293.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/Utils.h (+7-1)
  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+14-7)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp (+25-9)
  • (modified) mlir/test/Dialect/MemRef/normalize-memrefs.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 0f801ebb6f5898..ff1900bc8f2ebc 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -31,6 +31,7 @@ class FuncOp;
 
 namespace memref {
 class AllocOp;
+class AllocaOp;
 } // namespace memref
 
 namespace affine {
@@ -245,7 +246,12 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
 /// Rewrites the memref defined by this alloc op to have an identity layout map
 /// and updates all its indexing uses. Returns failure if any of its uses
 /// escape (while leaving the IR in a valid state).
-LogicalResult normalizeMemRef(memref::AllocOp *op);
+template <typename AllocLikeOp>
+LogicalResult normalizeMemRef(AllocLikeOp *op);
+extern template LogicalResult
+normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
+extern template LogicalResult
+normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
 
 /// Normalizes `memrefType` so that the affine layout map of the memref is
 /// transformed to an identity map with a new shape being computed for the
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 9e3257a62b12fb..7ef016f88be375 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1737,9 +1737,10 @@ static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
 ///   %c4 = arith.constant 4 : index
 ///   %1 = affine.apply #map1(%c4, %0)
 ///   %2 = affine.apply #map2(%c4, %0)
+template <typename AllocLikeOp>
 static void createNewDynamicSizes(MemRefType oldMemRefType,
                                   MemRefType newMemRefType, AffineMap map,
-                                  memref::AllocOp *allocOp, OpBuilder b,
+                                  AllocLikeOp *allocOp, OpBuilder b,
                                   SmallVectorImpl<Value> &newDynamicSizes) {
   // Create new input for AffineApplyOp.
   SmallVector<Value, 4> inAffineApply;
@@ -1786,7 +1787,8 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
 }
 
 // TODO: Currently works for static memrefs with a single layout map.
-LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
+template <typename AllocLikeOp>
+LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
   MemRefType memrefType = allocOp->getType();
   OpBuilder b(*allocOp);
 
@@ -1802,7 +1804,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
 
   SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
-  memref::AllocOp newAlloc;
+  AllocLikeOp newAlloc;
   // Check if `layoutMap` is a tiled layout. Only single layout map is
   // supported for normalizing dynamic memrefs.
   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
@@ -1814,11 +1816,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
                           newDynamicSizes);
     // Add the new dynamic sizes in new AllocOp.
     newAlloc =
-        b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
-                                  newDynamicSizes, allocOp->getAlignmentAttr());
+        b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
+                              allocOp->getAlignmentAttr());
   } else {
-    newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
-                                         allocOp->getAlignmentAttr());
+    newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
+                                     allocOp->getAlignmentAttr());
   }
   // Replace all uses of the old memref.
   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
@@ -1843,6 +1845,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
   return success();
 }
 
+template LogicalResult
+mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
+template LogicalResult
+mlir::affine::normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
+
 MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
   unsigned rank = memrefType.getRank();
   if (rank == 0)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 33772ccb7dd9d3..08b853fe65b857 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -151,11 +151,11 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
   });
 }
 
-/// Check whether all the uses of AllocOps, CallOps and function arguments of a
-/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
-/// or ReturnOp. Only if these constraints are satisfied will the function
-/// become a candidate for normalization. When the uses of a memref are
-/// non-normalizable and the memref map layout is trivial (identity), we can
+/// Check whether all the uses of AllocOps, AllocaOps, CallOps and function
+/// arguments of a function are either of dereferencing type or are uses in:
+/// DeallocOp, CallOp or ReturnOp. Only if these constraints are satisfied will
+/// the function become a candidate for normalization. When the uses of a memref
+/// are non-normalizable and the memref map layout is trivial (identity), we can
 /// still label the entire function as normalizable. We assume external
 /// functions to be normalizable.
 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
@@ -174,6 +174,17 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
           .wasInterrupted())
     return false;
 
+  if (funcOp
+          .walk([&](memref::AllocaOp allocaOp) -> WalkResult {
+            Value oldMemRef = allocaOp.getResult();
+            if (!allocaOp.getType().getLayout().isIdentity() &&
+                !isMemRefNormalizable(oldMemRef.getUsers()))
+              return WalkResult::interrupt();
+            return WalkResult::advance();
+          })
+          .wasInterrupted())
+    return false;
+
   if (funcOp
           .walk([&](func::CallOp callOp) -> WalkResult {
             for (unsigned resIndex :
@@ -335,18 +346,23 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
 }
 
 /// Normalizes the memrefs within a function which includes those arising as a
-/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
-/// is used to help update function's signature after normalization.
+/// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp
+/// argument is used to help update function's signature after normalization.
 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
                                               ModuleOp moduleOp) {
   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
-  // alloc ops first and then process since normalizeMemRef replaces/erases ops
-  // during memref rewriting.
+  // alloc/alloca ops first and then process since normalizeMemRef
+  // replaces/erases ops during memref rewriting.
   SmallVector<memref::AllocOp, 4> allocOps;
   funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
   for (memref::AllocOp allocOp : allocOps)
     (void)normalizeMemRef(&allocOp);
 
+  SmallVector<memref::AllocaOp> allocaOps;
+  funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
+  for (memref::AllocaOp allocaOp : allocaOps)
+    (void)normalizeMemRef(&allocaOp);
+
   // We use this OpBuilder to create new memref layout later.
   OpBuilder b(funcOp);
 
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index 6d20ccbf2ca055..e93a1a4ebae532 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -31,6 +31,20 @@ func.func @permute() {
 // CHECK-NEXT: memref.dealloc [[MEM]]
 // CHECK-NEXT: return
 
+// CHECK-LABEL: func @alloca
+func.func @alloca(%idx : index) {
+  // CHECK-NEXT: memref.alloca() : memref<65xf32>
+  %A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
+  affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+  affine.for %i = 0 to 64 {
+    %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+    "prevent.dce"(%1) : (f32) -> ()
+    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
+  }
+  return
+}
+
 // CHECK-LABEL: func @shift
 func.func @shift(%idx : index) {
   // CHECK-NEXT: memref.alloc() : memref<65xf32>

@makslevental
Copy link
Contributor

@MaheshRavishankar maybe you have enough context to review as well? Sorry if no...

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont have much context about what this is doing w.r.t affine dialect, but the PR itself looks fine to me.

@mgehre-amd mgehre-amd merged commit 2ec2784 into llvm:main Jan 29, 2025
12 checks passed
@mgehre-amd mgehre-amd deleted the matthias.normalize_alloca branch January 29, 2025 07:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants