Skip to content

Commit 2ec2784

Browse files
authored
[MLIR] normalize-memrefs: Normalize memref.alloca (llvm#123293)
The pass was only handling `memref.alloc` and this extends it to also handle `memref.alloca`.
1 parent 1c3ea59 commit 2ec2784

File tree

4 files changed

+60
-17
lines changed

4 files changed

+60
-17
lines changed

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class FuncOp;
3131

3232
namespace memref {
3333
class AllocOp;
34+
class AllocaOp;
3435
} // namespace memref
3536

3637
namespace affine {
@@ -245,7 +246,12 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
245246
/// Rewrites the memref defined by this alloc op to have an identity layout map
246247
/// and updates all its indexing uses. Returns failure if any of its uses
247248
/// escape (while leaving the IR in a valid state).
248-
LogicalResult normalizeMemRef(memref::AllocOp *op);
249+
template <typename AllocLikeOp>
250+
LogicalResult normalizeMemRef(AllocLikeOp *op);
251+
extern template LogicalResult
252+
normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
253+
extern template LogicalResult
254+
normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
249255

250256
/// Normalizes `memrefType` so that the affine layout map of the memref is
251257
/// transformed to an identity map with a new shape being computed for the

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,9 +1737,10 @@ static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
17371737
/// %c4 = arith.constant 4 : index
17381738
/// %1 = affine.apply #map1(%c4, %0)
17391739
/// %2 = affine.apply #map2(%c4, %0)
1740+
template <typename AllocLikeOp>
17401741
static void createNewDynamicSizes(MemRefType oldMemRefType,
17411742
MemRefType newMemRefType, AffineMap map,
1742-
memref::AllocOp *allocOp, OpBuilder b,
1743+
AllocLikeOp *allocOp, OpBuilder b,
17431744
SmallVectorImpl<Value> &newDynamicSizes) {
17441745
// Create new input for AffineApplyOp.
17451746
SmallVector<Value, 4> inAffineApply;
@@ -1786,7 +1787,8 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
17861787
}
17871788

17881789
// TODO: Currently works for static memrefs with a single layout map.
1789-
LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
1790+
template <typename AllocLikeOp>
1791+
LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
17901792
MemRefType memrefType = allocOp->getType();
17911793
OpBuilder b(*allocOp);
17921794

@@ -1802,7 +1804,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
18021804

18031805
SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
18041806
AffineMap layoutMap = memrefType.getLayout().getAffineMap();
1805-
memref::AllocOp newAlloc;
1807+
AllocLikeOp newAlloc;
18061808
// Check if `layoutMap` is a tiled layout. Only single layout map is
18071809
// supported for normalizing dynamic memrefs.
18081810
SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
@@ -1814,11 +1816,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
18141816
newDynamicSizes);
18151817
// Add the new dynamic sizes in new AllocOp.
18161818
newAlloc =
1817-
b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
1818-
newDynamicSizes, allocOp->getAlignmentAttr());
1819+
b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
1820+
allocOp->getAlignmentAttr());
18191821
} else {
1820-
newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
1821-
allocOp->getAlignmentAttr());
1822+
newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
1823+
allocOp->getAlignmentAttr());
18221824
}
18231825
// Replace all uses of the old memref.
18241826
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
@@ -1843,6 +1845,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
18431845
return success();
18441846
}
18451847

1848+
template LogicalResult
1849+
mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
1850+
template LogicalResult
1851+
mlir::affine::normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
1852+
18461853
MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
18471854
unsigned rank = memrefType.getRank();
18481855
if (rank == 0)

mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,11 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
151151
});
152152
}
153153

154-
/// Check whether all the uses of AllocOps, CallOps and function arguments of a
155-
/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
156-
/// or ReturnOp. Only if these constraints are satisfied will the function
157-
/// become a candidate for normalization. When the uses of a memref are
158-
/// non-normalizable and the memref map layout is trivial (identity), we can
154+
/// Check whether all the uses of AllocOps, AllocaOps, CallOps and function
155+
/// arguments of a function are either of dereferencing type or are uses in:
156+
/// DeallocOp, CallOp or ReturnOp. Only if these constraints are satisfied will
157+
/// the function become a candidate for normalization. When the uses of a memref
158+
/// are non-normalizable and the memref map layout is trivial (identity), we can
159159
/// still label the entire function as normalizable. We assume external
160160
/// functions to be normalizable.
161161
bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
@@ -174,6 +174,17 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
174174
.wasInterrupted())
175175
return false;
176176

177+
if (funcOp
178+
.walk([&](memref::AllocaOp allocaOp) -> WalkResult {
179+
Value oldMemRef = allocaOp.getResult();
180+
if (!allocaOp.getType().getLayout().isIdentity() &&
181+
!isMemRefNormalizable(oldMemRef.getUsers()))
182+
return WalkResult::interrupt();
183+
return WalkResult::advance();
184+
})
185+
.wasInterrupted())
186+
return false;
187+
177188
if (funcOp
178189
.walk([&](func::CallOp callOp) -> WalkResult {
179190
for (unsigned resIndex :
@@ -335,18 +346,23 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
335346
}
336347

337348
/// Normalizes the memrefs within a function which includes those arising as a
338-
/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
339-
/// is used to help update function's signature after normalization.
349+
/// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp
350+
/// argument is used to help update function's signature after normalization.
340351
void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
341352
ModuleOp moduleOp) {
342353
// Turn memrefs' non-identity layouts maps into ones with identity. Collect
343-
// alloc ops first and then process since normalizeMemRef replaces/erases ops
344-
// during memref rewriting.
354+
// alloc/alloca ops first and then process since normalizeMemRef
355+
// replaces/erases ops during memref rewriting.
345356
SmallVector<memref::AllocOp, 4> allocOps;
346357
funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
347358
for (memref::AllocOp allocOp : allocOps)
348359
(void)normalizeMemRef(&allocOp);
349360

361+
SmallVector<memref::AllocaOp> allocaOps;
362+
funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
363+
for (memref::AllocaOp allocaOp : allocaOps)
364+
(void)normalizeMemRef(&allocaOp);
365+
350366
// We use this OpBuilder to create new memref layout later.
351367
OpBuilder b(funcOp);
352368

mlir/test/Dialect/MemRef/normalize-memrefs.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ func.func @permute() {
3131
// CHECK-NEXT: memref.dealloc [[MEM]]
3232
// CHECK-NEXT: return
3333

34+
// CHECK-LABEL: func @alloca
35+
func.func @alloca(%idx : index) {
36+
// CHECK-NEXT: memref.alloca() : memref<65xf32>
37+
%A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
38+
// CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
39+
affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
40+
affine.for %i = 0 to 64 {
41+
%1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
42+
"prevent.dce"(%1) : (f32) -> ()
43+
// CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
44+
}
45+
return
46+
}
47+
3448
// CHECK-LABEL: func @shift
3549
func.func @shift(%idx : index) {
3650
// CHECK-NEXT: memref.alloc() : memref<65xf32>

0 commit comments

Comments
 (0)