Skip to content

Commit 56984f1

Browse files
simpel01 and Priyanshu3820
authored and committed
[mlir][memref] Generalize dead store detection to all view-like ops (llvm#168507)
The dead alloc elimination pass previously considered only subviews when checking for dead stores. This change generalizes the logic to support all view-like operations, ensuring broader coverage.
1 parent 99103be commit 56984f1

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,7 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
133133
}
134134

135135
/// Returns true if all the uses of op are not read/load.
136-
/// There can be SubviewOp users as long as all its users are also
136+
/// There can be view-like-op users as long as all its users are also
137137
/// StoreOp/transfer_write. If return true it also fills out the uses, if it
138138
/// returns false uses is unchanged.
139139
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
@@ -146,7 +146,7 @@ static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
146146
if (isa<memref::DeallocOp>(useOp) ||
147147
(useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
148148
!mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
149-
(isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
149+
(isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) {
150150
opUses.push_back(useOp);
151151
continue;
152152
}

mlir/test/Dialect/MemRef/transform-ops.mlir

Lines changed: 67 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -395,6 +395,73 @@ module attributes {transform.with_named_sequence} {
395395

396396
// -----
397397

398+
// CHECK-LABEL: @dead_store_through_subview
399+
// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
400+
// CHECK-NOT: memref.alloc()
401+
// CHECK-NOT: vector.transfer_write
402+
func.func @dead_store_through_subview(%arg: vector<4xf32>) {
403+
%c0 = arith.constant 0 : index
404+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
405+
%subview = memref.subview %alloc[%c0] [4] [1] : memref<64xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
406+
vector.transfer_write %arg, %subview[%c0] {in_bounds = [true]}
407+
: vector<4xf32>, memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
408+
return
409+
}
410+
411+
module attributes {transform.with_named_sequence} {
412+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
413+
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
414+
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
415+
transform.yield
416+
}
417+
}
418+
419+
// -----
420+
421+
// CHECK-LABEL: @dead_store_through_expand
422+
// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
423+
// CHECK-NOT: memref.alloc()
424+
// CHECK-NOT: vector.transfer_write
425+
func.func @dead_store_through_expand(%arg: vector<4xf32>) {
426+
%c0 = arith.constant 0 : index
427+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
428+
%expand = memref.expand_shape %alloc [[0, 1]] output_shape [16, 4] : memref<64xf32> into memref<16x4xf32>
429+
vector.transfer_write %arg, %expand[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<16x4xf32>
430+
return
431+
}
432+
433+
module attributes {transform.with_named_sequence} {
434+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
435+
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
436+
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
437+
transform.yield
438+
}
439+
}
440+
441+
// -----
442+
443+
// CHECK-LABEL: @dead_store_through_collapse
444+
// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
445+
// CHECK-NOT: memref.alloc()
446+
// CHECK-NOT: vector.transfer_write
447+
func.func @dead_store_through_collapse(%arg: vector<4xf32>) {
448+
%c0 = arith.constant 0 : index
449+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<16x4xf32>
450+
%collapse = memref.collapse_shape %alloc [[0, 1]] : memref<16x4xf32> into memref<64xf32>
451+
vector.transfer_write %arg, %collapse[%c0] {in_bounds = [true]} : vector<4xf32>, memref<64xf32>
452+
return
453+
}
454+
455+
module attributes {transform.with_named_sequence} {
456+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
457+
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
458+
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
459+
transform.yield
460+
}
461+
}
462+
463+
// -----
464+
398465
// CHECK-LABEL: func @lower_to_llvm
399466
// CHECK-NOT: memref.alloc
400467
// CHECK: llvm.call @malloc

0 commit comments

Comments
 (0)