
Commit bb17dfa

[mlir][bufferization] Enable moving dependent values in eliminate-empty-tensors (#169718)
Currently, empty tensor elimination, which constructs a SubsetExtractionOp to match the SubsetInsertionOp at the end of a DPS chain, fails if any operand required by the insertion op does not dominate the insertion point chosen for the extraction op. This change improves the transformation by attempting to move all pure producers of the required operands up to the extraction op's insertion point. In the process, this strengthens a number of tests for empty tensor elimination.
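
As a concrete illustration (a minimal sketch mirroring the new @move_dependent_arith_op test added in this commit; value names such as %offset and %slice are illustrative), the offset operand of the insertion op is produced after the tensor.empty/linalg.fill chain, so previously no valid insertion point existed for the replacement extraction:

  // Before: %offset is defined below the fill, so the tensor.extract_slice
  // that should replace %0 = tensor.empty() had no legal insertion point.
  %0 = tensor.empty() : tensor<5xf32>
  %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
  %offset = arith.addi %arg1, %c5 : index
  %2 = tensor.insert_slice %1 into %arg0[%offset] [5] [1] : tensor<5xf32> into tensor<10xf32>

  // After: the pure producers of %offset are moved above the fill, and the
  // empty tensor is replaced by a slice of the destination.
  %offset = arith.addi %arg1, %c5 : index
  %slice = tensor.extract_slice %arg0[%offset] [5] [1] : tensor<10xf32> to tensor<5xf32>
  %1 = linalg.fill ins(%f : f32) outs(%slice : tensor<5xf32>) -> tensor<5xf32>
  %2 = tensor.insert_slice %1 into %arg0[%offset] [5] [1] : tensor<5xf32> into tensor<10xf32>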
1 parent 29fa151

File tree

6 files changed: +197 −118 lines changed

mlir/include/mlir/Transforms/RegionUtils.h

Lines changed: 2 additions & 1 deletion
@@ -84,7 +84,8 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
 /// Move definitions of `values` before an insertion point. Current support is
 /// only for movement of definitions within the same basic block. Note that this
 /// is an all-or-nothing approach. Either definitions of all values are moved
-/// before insertion point, or none of them are.
+/// before insertion point, or none of them are. Any side-effecting operations
+/// in the producer chain pessimistically blocks movement.
 LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
                                    Operation *insertionPoint,
                                    DominanceInfo &dominance);
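
To picture the documented contract (an illustrative sketch, not code from this patch; %offset, %idx, and fillOp are made-up names): calling moveValueDefinitions with values = {%offset} and the linalg.fill as the insertion point hoists the entire pure producer chain of %offset; if any producer cannot be moved, none are.

  // Before moveValueDefinitions(rewriter, {%offset}, fillOp, dominance):
  %fill = linalg.fill ins(%f : f32) outs(%t : tensor<5xf32>) -> tensor<5xf32>
  %c1 = arith.constant 1 : index
  %offset = arith.addi %idx, %c1 : index

  // After a successful call: both producers of %offset now dominate the fill.
  %c1 = arith.constant 1 : index
  %offset = arith.addi %idx, %c1 : index
  %fill = linalg.fill ins(%f : f32) outs(%t : tensor<5xf32>) -> tensor<5xf32>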

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 8 additions & 2 deletions
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 namespace mlir {
 namespace bufferization {
@@ -105,8 +106,13 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
   // this replacement.
   Operation *insertionPoint =
       findValidInsertionPoint(emptyTensorOp, user, neededValues);
-  if (!insertionPoint)
-    return {};
+  if (!insertionPoint) {
+    // If no already suitable insertion point was found, attempt to move all
+    // needed values before the user.
+    if (failed(moveValueDefinitions(rewriter, neededValues, user)))
+      return {};
+    insertionPoint = user;
+  }
 
   rewriter.setInsertionPoint(insertionPoint);
   Value replacement =

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 12 additions & 5 deletions
@@ -1149,9 +1149,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
   // Remove the values that already dominate the insertion point.
   SmallVector<Value> prunedValues;
   for (auto value : values) {
-    if (dominance.properlyDominates(value, insertionPoint)) {
+    if (dominance.properlyDominates(value, insertionPoint))
       continue;
-    }
     // Block arguments are not supported.
     if (isa<BlockArgument>(value)) {
       return rewriter.notifyMatchFailure(
@@ -1178,8 +1177,13 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
   // Since current support is to only move within a same basic block,
   // the slices dont need to look past block arguments.
   options.omitBlockArguments = true;
+  bool dependsOnSideEffectingOp = false;
   options.filter = [&](Operation *sliceBoundaryOp) {
-    return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+    bool mustMove =
+        !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+    if (mustMove && !isPure(sliceBoundaryOp))
+      dependsOnSideEffectingOp = true;
+    return mustMove;
   };
   llvm::SetVector<Operation *> slice;
   for (auto value : prunedValues) {
@@ -1188,6 +1192,10 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
     (void)result;
   }
 
+  // Check if any operation in the slice is side-effecting.
+  if (dependsOnSideEffectingOp)
+    return failure();
+
   // If the slice contains `insertionPoint` cannot move the dependencies.
   if (slice.contains(insertionPoint)) {
     return rewriter.notifyMatchFailure(
@@ -1198,9 +1206,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
   // Sort operations topologically before moving.
   mlir::topologicalSort(slice);
 
-  for (Operation *op : slice) {
+  for (Operation *op : slice)
     rewriter.moveOpBefore(op, insertionPoint);
-  }
   return success();
 }
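
In IR terms, the new guard rejects producer chains like the following illustrative sketch (the @side_effecting_op_blocks_movement test added below exercises the same case), since hoisting a load past other memory effects could change observable behavior:

  %fill = linalg.fill ins(%f : f32) outs(%t : tensor<5xf32>) -> tensor<5xf32>
  // memref.load is not pure, so the slice filter sets dependsOnSideEffectingOp
  // and moveValueDefinitions returns failure() instead of hoisting it.
  %offset = memref.load %mem[] : memref<index>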

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 80 additions & 19 deletions
@@ -368,21 +368,18 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
 
 // -----
 
-// `EmptyTensorElimination` fails to find a valid insertion
-// point for the new injected `SubsetExtraction`.
-// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
-func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
+// CHECK-LABEL: func.func @eliminate_all_empty_tensors
+func.func @eliminate_all_empty_tensors() -> tensor<5x6x128xf32> {
   %cst_1 = arith.constant 1.0 : f32
   %cst_2 = arith.constant 2.0 : f32
-  // CHECK: memref.alloc
-  // CHECK: memref.alloc
-  // CHECK: memref.alloc
+  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+  // CHECK-NOT: memref.alloc
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
   %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %empty_2 = tensor.empty() : tensor<5x6x64xf32>
   %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
-  // CHECK: memref.copy
+  // CHECK-NOT: memref.copy
   %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
   : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
@@ -392,20 +389,19 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
 
 // -----
 
-// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor
-func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
+// CHECK-LABEL: func.func @eliminate_concatenated_empty_tensors
+func.func @eliminate_concatenated_empty_tensors() -> tensor<5x6x128xf32> {
   %cst_1 = arith.constant 1.0 : f32
   %cst_2 = arith.constant 2.0 : f32
   // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
-  // CHECK: memref.alloc
   // CHECK-NOT: memref.alloc
-  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %concatenated_empty = tensor.empty() : tensor<5x6x128xf32>
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
   %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %empty_2 = tensor.empty() : tensor<5x6x64xf32>
   %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
-  // CHECK: memref.copy
-  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+  // CHECK-NOT: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %concatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
   : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
   : tensor<5x6x64xf32> into tensor<5x6x128xf32>
@@ -420,20 +416,22 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
 
 // CHECK-ELIM-LABEL: func.func @multi_use_of_the_same_tensor_empty
 // CHECK-LABEL: func.func @multi_use_of_the_same_tensor_empty
+// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+// CHECK-NOT: memref.alloc
+// CHECK-NOT: memref.copy
+// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 0]
+// CHECK-ELIM: linalg.fill
+// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 64]
+// CHECK-ELIM: linalg.fill
 func.func @multi_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
   %cst_1 = arith.constant 1.0 : f32
   %cst_2 = arith.constant 2.0 : f32
   %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
-  // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
-  // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
-  // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
   %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
-  // CHECK: memref.copy
   %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
   : tensor<5x6x64xf32> into tensor<5x6x128xf32>
-  // CHECK-NOT: memref.copy
   %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
   : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   return %inserted_slice_2 : tensor<5x6x128xf32>
@@ -476,3 +474,66 @@ func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x
   : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   return %inserted_slice_1 : tensor<5x6x128xf32>
 }
+
+// -----
+
+// Test that dependent pure operations are moved before the
+// insertion point to enable empty tensor elimination.
+
+// CHECK-LABEL: func.func @move_dependent_arith_op(
+// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-NOT: memref.alloc
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
+// CHECK: %[[SV:.*]] = memref.subview %[[ARG0]][%[[OFFSET]]] [5] [1]
+// CHECK: linalg.fill {{.*}} outs(%[[SV]]
+// CHECK: return %[[ARG0]]
+// CHECK-ELIM-LABEL: func.func @move_dependent_arith_op(
+// CHECK-ELIM-SAME: %[[ARG0:.*]]: tensor<10xf32>
+// CHECK-ELIM-SAME: %[[ARG1:.*]]: index
+// CHECK-ELIM: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-ELIM: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
+// CHECK-ELIM: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[OFFSET]]] [5] [1]
+// CHECK-ELIM: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]]
+// CHECK-ELIM: tensor.insert_slice %[[FILL]] into %[[ARG0]][%[[OFFSET]]]
+func.func @move_dependent_arith_op(
+    %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+    %arg1: index, %f: f32) -> tensor<10xf32>
+{
+  %0 = tensor.empty() : tensor<5xf32>
+  %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  %c5 = arith.constant 5 : index
+  %offset = arith.addi %arg1, %c5 : index
+  %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
+  : tensor<5xf32> into tensor<10xf32>
+  return %2 : tensor<10xf32>
+}
+
+// -----
+
+// Test that side-effecting operations are not moved, preventing empty
+// tensor elimination.
+
+// CHECK-LABEL: func.func @side_effecting_op_blocks_movement(
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: memref.load
+// CHECK: memref.subview
+// CHECK: memref.copy
+// CHECK-ELIM-LABEL: func.func @side_effecting_op_blocks_movement(
+// CHECK-ELIM: tensor.empty
+// CHECK-ELIM: linalg.fill
+// CHECK-ELIM: memref.load
+// CHECK-ELIM: tensor.insert_slice
+func.func @side_effecting_op_blocks_movement(
+    %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+    %mem: memref<index>, %f: f32) -> tensor<10xf32>
+{
+  %0 = tensor.empty() : tensor<5xf32>
+  %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  %offset = memref.load %mem[] : memref<index>
+  %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
+  : tensor<5xf32> into tensor<10xf32>
+  return %2 : tensor<10xf32>
+}
