Skip to content

Commit bada367

Browse files
authored
[FXML-5417] TileUsingInterface: drop unused extract_slice (#432)
1 parent 752540d commit bada367

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,6 +1536,12 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15361536
tiledAndFusedOps.insert(tiledAndFusedOp);
15371537
}
15381538

1539+
// Drop the extract_slice if it has been replaced by the tiled producer, and
1540+
// is no longer used.
1541+
if (worklistItem.candidateSlice->use_empty()) {
1542+
rewriter.eraseOp(worklistItem.candidateSlice);
1543+
}
1544+
15391545
if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
15401546
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
15411547
}

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,46 @@ module {
651651
}
652652
}
653653

654+
// -----
655+
656+
// This test checks that upon tiling and fusion, Linalg ops that have been tiled
657+
// through fusion and are not used elsewhere are indeed dead code and get
658+
// dropped.
659+
660+
// CHECK-LABEL: func @tile_fuse_drop_dead_producer(
661+
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
662+
func.func @tile_fuse_drop_dead_producer(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
663+
%c2f = arith.constant 2.0 : f32
664+
665+
// CHECK-NOT: linalg.generic {{{[^\}]*}}} ins(%[[TA]] : tensor<10x10xf32>) outs(%{{.*}} : tensor<10x10xf32>) {
666+
%empty = tensor.empty() : tensor<10x10xf32>
667+
%0 = linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"]}
668+
ins(%arg0: tensor<10x10xf32>) outs(%empty: tensor<10x10xf32>) {
669+
^bb0(%a: f32, %b: f32):
670+
%res = arith.addf %a, %c2f : f32
671+
linalg.yield %res : f32
672+
} -> tensor<10x10xf32>
673+
674+
%empty2 = tensor.empty() : tensor<10x10xf32>
675+
// CHECK: scf.for {{.*}} {
676+
// CHECK: scf.for {{.*}} {
677+
// CHECK: linalg.generic
678+
// CHECK: linalg.negf
679+
// CHECK: }
680+
// CHECK: }
681+
%1 = linalg.negf ins(%0 : tensor<10x10xf32>) outs(%empty2 : tensor<10x10xf32>) -> tensor<10x10xf32>
682+
683+
return %1 : tensor<10x10xf32>
684+
}
685+
686+
module attributes {transform.with_named_sequence} {
687+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
688+
%0 = transform.structured.match ops{["linalg.negf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
689+
%tiled_low, %loop1, %loop2 = transform.structured.fuse %0 [5, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
690+
transform.yield
691+
}
692+
}
693+
654694

655695
////////////////////////////////////////////////////////////////////////////////
656696
// Tests below are expected to fail.

0 commit comments

Comments
 (0)