[MLIR] normalize-memrefs: Normalize memref.alloca #123293
Conversation
The pass was only handling memref.alloc, and this extends it to also handle memref.alloca.
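In short (a sketch mirroring the new test case added in this PR): an alloca with a non-identity layout map such as

    %A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
    affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>

is now rewritten to use an identity layout, with the old layout map folded into the accesses:

    %A = memref.alloca() : memref<65xf32>
    affine.load %A[symbol(%arg0) + 1] : memref<65xf32>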
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref

Author: Matthias Gehre (mgehre-amd)

Changes: The pass was only handling memref.alloc, and this extends it to also handle memref.alloca.

Full diff: https://github.com/llvm/llvm-project/pull/123293.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 0f801ebb6f5898..ff1900bc8f2ebc 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -31,6 +31,7 @@ class FuncOp;
namespace memref {
class AllocOp;
+class AllocaOp;
} // namespace memref
namespace affine {
@@ -245,7 +246,12 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
/// Rewrites the memref defined by this alloc op to have an identity layout map
/// and updates all its indexing uses. Returns failure if any of its uses
/// escape (while leaving the IR in a valid state).
-LogicalResult normalizeMemRef(memref::AllocOp *op);
+template <typename AllocLikeOp>
+LogicalResult normalizeMemRef(AllocLikeOp *op);
+extern template LogicalResult
+normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
+extern template LogicalResult
+normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
/// Normalizes `memrefType` so that the affine layout map of the memref is
/// transformed to an identity map with a new shape being computed for the
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 9e3257a62b12fb..7ef016f88be375 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1737,9 +1737,10 @@ static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
/// %c4 = arith.constant 4 : index
/// %1 = affine.apply #map1(%c4, %0)
/// %2 = affine.apply #map2(%c4, %0)
+template <typename AllocLikeOp>
static void createNewDynamicSizes(MemRefType oldMemRefType,
MemRefType newMemRefType, AffineMap map,
- memref::AllocOp *allocOp, OpBuilder b,
+ AllocLikeOp *allocOp, OpBuilder b,
SmallVectorImpl<Value> &newDynamicSizes) {
// Create new input for AffineApplyOp.
SmallVector<Value, 4> inAffineApply;
@@ -1786,7 +1787,8 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
}
// TODO: Currently works for static memrefs with a single layout map.
-LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
+template <typename AllocLikeOp>
+LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
MemRefType memrefType = allocOp->getType();
OpBuilder b(*allocOp);
@@ -1802,7 +1804,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
AffineMap layoutMap = memrefType.getLayout().getAffineMap();
- memref::AllocOp newAlloc;
+ AllocLikeOp newAlloc;
// Check if `layoutMap` is a tiled layout. Only single layout map is
// supported for normalizing dynamic memrefs.
SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
@@ -1814,11 +1816,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
newDynamicSizes);
// Add the new dynamic sizes in new AllocOp.
newAlloc =
- b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
- newDynamicSizes, allocOp->getAlignmentAttr());
+ b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
+ allocOp->getAlignmentAttr());
} else {
- newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
- allocOp->getAlignmentAttr());
+ newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
+ allocOp->getAlignmentAttr());
}
// Replace all uses of the old memref.
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
@@ -1843,6 +1845,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
return success();
}
+template LogicalResult
+mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
+template LogicalResult
+mlir::affine::normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
+
MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
unsigned rank = memrefType.getRank();
if (rank == 0)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 33772ccb7dd9d3..08b853fe65b857 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -151,11 +151,11 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
});
}
-/// Check whether all the uses of AllocOps, CallOps and function arguments of a
-/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
-/// or ReturnOp. Only if these constraints are satisfied will the function
-/// become a candidate for normalization. When the uses of a memref are
-/// non-normalizable and the memref map layout is trivial (identity), we can
+/// Check whether all the uses of AllocOps, AllocaOps, CallOps and function
+/// arguments of a function are either of dereferencing type or are uses in:
+/// DeallocOp, CallOp or ReturnOp. Only if these constraints are satisfied will
+/// the function become a candidate for normalization. When the uses of a memref
+/// are non-normalizable and the memref map layout is trivial (identity), we can
/// still label the entire function as normalizable. We assume external
/// functions to be normalizable.
bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
@@ -174,6 +174,17 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
.wasInterrupted())
return false;
+ if (funcOp
+ .walk([&](memref::AllocaOp allocaOp) -> WalkResult {
+ Value oldMemRef = allocaOp.getResult();
+ if (!allocaOp.getType().getLayout().isIdentity() &&
+ !isMemRefNormalizable(oldMemRef.getUsers()))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ })
+ .wasInterrupted())
+ return false;
+
if (funcOp
.walk([&](func::CallOp callOp) -> WalkResult {
for (unsigned resIndex :
@@ -335,18 +346,23 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
}
/// Normalizes the memrefs within a function which includes those arising as a
-/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
-/// is used to help update function's signature after normalization.
+/// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp
+/// argument is used to help update function's signature after normalization.
void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
ModuleOp moduleOp) {
// Turn memrefs' non-identity layouts maps into ones with identity. Collect
- // alloc ops first and then process since normalizeMemRef replaces/erases ops
- // during memref rewriting.
+ // alloc/alloca ops first and then process since normalizeMemRef
+ // replaces/erases ops during memref rewriting.
SmallVector<memref::AllocOp, 4> allocOps;
funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
for (memref::AllocOp allocOp : allocOps)
(void)normalizeMemRef(&allocOp);
+ SmallVector<memref::AllocaOp> allocaOps;
+ funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
+ for (memref::AllocaOp allocaOp : allocaOps)
+ (void)normalizeMemRef(&allocaOp);
+
// We use this OpBuilder to create new memref layout later.
OpBuilder b(funcOp);
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index 6d20ccbf2ca055..e93a1a4ebae532 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -31,6 +31,20 @@ func.func @permute() {
// CHECK-NEXT: memref.dealloc [[MEM]]
// CHECK-NEXT: return
+// CHECK-LABEL: func @alloca
+func.func @alloca(%idx : index) {
+ // CHECK-NEXT: memref.alloca() : memref<65xf32>
+ %A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+ // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
+ affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+ affine.for %i = 0 to 64 {
+ %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+ "prevent.dce"(%1) : (f32) -> ()
+ // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
+ }
+ return
+}
+
// CHECK-LABEL: func @shift
func.func @shift(%idx : index) {
// CHECK-NEXT: memref.alloc() : memref<65xf32>
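For anyone trying this locally: the updated test exercises the pass through mlir-opt. Assuming a standard MLIR build, a minimal invocation would look like

    mlir-opt -normalize-memrefs mlir/test/Dialect/MemRef/normalize-memrefs.mlir

(flag name taken from the pass's registered name, normalize-memrefs).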
@MaheshRavishankar maybe you have enough context to review as well? Sorry if not...
MaheshRavishankar left a comment
I don't have much context about what this is doing w.r.t. the affine dialect, but the PR itself looks fine to me.