diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp index 5396531922aab..0f5fa61879b71 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" #include @@ -67,6 +68,10 @@ class InsertSliceOfTransferWriteOpFolder final LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const override; + +private: + static bool + doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp); }; } // namespace @@ -136,6 +141,10 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite( if (failed(preconditionResult)) return preconditionResult; + if (!doesTransferWriteCoverInsertSlice(writeOp)) + return rewriter.notifyMatchFailure( + insertSliceOp, "transfer_write does not cover insert_slice"); + SmallVector indices(writeOp.getIndices().begin(), writeOp.getIndices().end()); SmallVector sourceIndices; @@ -154,6 +163,17 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite( return success(); } +bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice( + vector::TransferWriteOp writeOp) { + if (writeOp.getShapedType().hasStaticShape()) + return llvm::equal(writeOp.getVectorType().getShape(), + writeOp.getShapedType().getShape()); + + // TODO: Use ValueBoundsConstraintSet for dynamic shapes. + + return false; +} + template struct InsertSliceOfInsertSliceFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir index 1a84e14104932..988b5d835c16e 100644 --- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir +++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir @@ -144,8 +144,6 @@ func.func @transfer_read_of_extract_slice_swappy_rank_reducing(%t : tensor (s0 + s1)> - // CHECK: func @fold_vector_transfer_write_with_rank_reduced_insert_slice // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32> @@ -155,6 +153,7 @@ func.func @transfer_read_of_extract_slice_swappy_rank_reducing(%t : tensor func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice( %arg0 : tensor, %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index, @@ -162,11 +161,8 @@ func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice( %st : tensor) -> tensor { %cst = arith.constant 0.0 : f32 -// CHECK-NOT: insert_slice -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]] -// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]] -// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]] {in_bounds = [true]} : vector<4xf32>, tensor, tensor + // CHECK-DAG: %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[ARG0]][0, %[[ARG2]], %[[ARG3]]] [1, %[[ARG4]], %[[ARG5]]] [1, 1, 1] : tensor into tensor %0 = vector.transfer_write %arg1, %st[%arg6, %arg7] {in_bounds = [true]} : vector<4xf32>, tensor %1 = tensor.insert_slice %0 into %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1] @@ -176,9 +172,6 @@ func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice( // ----- -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)> - // CHECK: func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32> @@ -188,6 +181,7 @@ func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice( // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index // CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: tensor func.func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice( %arg0 : tensor, %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index, @@ -195,12 +189,8 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice( %st : tensor) -> tensor { %cst = arith.constant 0.0 : f32 - // CHECK-NOT: insert_slice - // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index - // CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]] - // CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]] - // CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]] - // CHECK-SAME: {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, tensor, tensor + // CHECK-DAG: %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[ARG0]][%[[ARG2]], %[[ARG3]], 0] [%[[ARG4]], %[[ARG5]], 1] [1, 1, 1] : tensor into tensor %0 = vector.transfer_write %arg1, %st[%arg6, %arg7] {in_bounds = [true]} : vector<4xf32>, tensor %1 = tensor.insert_slice %0 into %arg0[%arg2, %arg3, 0] [%arg4, %arg5, 1] [1, 1, 1] @@ -226,6 +216,24 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor, %v : vector<5x // ----- +// This test is negative since `transfer_write` only +// writes to `5x6` of the `100x100` elements of `%arg3` +// CHECK-LABEL: func @insert_slice_of_transfer_write_overwrite_all( +// CHECK-SAME: %[[arg0:.*]]: tensor<1000x1000xf32>, %[[arg1:.*]]: vector<5x6xf32>, %[[arg2:.*]]: index, %[[arg3:.*]]: tensor<100x100xf32> +func.func @insert_slice_of_transfer_write_overwrite_all(%arg0: tensor<1000x1000xf32>, %arg1: vector<5x6xf32>, %arg2: index, %arg3: tensor<100x100xf32>) -> tensor<1000x1000xf32> { + %c0 = arith.constant 0 : index + +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[r1:.*]] = vector.transfer_write %[[arg1]], %[[arg3]][%[[c0]], %[[c0]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32> +// CHECK: %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[arg0]][3, %[[arg2]]] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32> +// CHECK: return %[[r2]] : tensor<1000x1000xf32> + %0 = vector.transfer_write %arg1, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32> + %inserted_slice = tensor.insert_slice %0 into %arg0[3, %arg2] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32> + return %inserted_slice : tensor<1000x1000xf32> +} + +// ----- + // CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending(