From 1abdf4fff12379f52353ed18c0e02f4a6a21a7c8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski@arm.com>
Date: Wed, 15 Jan 2025 08:48:29 +0000
Subject: [PATCH 1/3] [mlir][linalg] Add support for masked vectorization of
 `tensor.insert_slice` (2/N)

For context, recall that `tensor.insert_slice` is vectorised using the
`vector.transfer_read` + `vector.transfer_write` pair. An unmasked example is
shown below:

```mlir
// BEFORE VECTORIZATION
%res = tensor.insert_slice %slice into %dest[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<5x3xi32>

// AFTER VECTORIZATION
%read = vector.transfer_read %slice[%c0, %c0], %pad : tensor<5x1xi32>, vector<8x1xi32>
%res = vector.transfer_write %read, %dest[%c0, %c2] : vector<8x1xi32>, tensor<5x3xi32>
```

This PR extends `vectorizeAsInsertSliceOp` to add masking support for the
`vector.transfer_write` operation. This complements the changes in #122927,
which introduced masking for the `vector.transfer_read`.
---
 .../Linalg/Transforms/Vectorization.cpp       | 48 +++++-----
 .../Linalg/vectorization-unsupported.mlir     | 23 -----
 mlir/test/Dialect/Linalg/vectorization.mlir   | 89 +++++++++++++++++--
 3 files changed, 108 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a43fa86166e83..1b50c24fbfcfc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2716,15 +2716,12 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   }
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
-  // 3. Generate TransferReadOp.
-  SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
+  // 3. Generate TransferReadOp + TransferWriteOp
+  ReifiedRankedShapedTypeDims reifiedSrcSizes;
+  Value maskOp;
 
-  // If vector sizes are user provided, make sure to mask xfer_read.
+  // If vector sizes are user provided, make sure to mask. First, generate the
+  // mask.
   if (!inputVectorSizes.empty()) {
     auto *srcDefOp = source.getDefiningOp();
     if (!srcDefOp) {
@@ -2732,40 +2729,47 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
       return failure();
     }
 
-    ReifiedRankedShapedTypeDims reifiedSrcSizes;
     LogicalResult status =
         cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
             rewriter, reifiedSrcSizes);
     if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << sliceOp);
+      LDBG("Unable to reify result shapes of " << srcDefOp);
       return failure();
     }
 
     // Create the mask
-    SmallVector<int64_t> readMaskShape(
-        sliceOp.getSource().getType().getShape());
     auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    Value maskOp = rewriter.create<vector::CreateMaskOp>(
+    maskOp = rewriter.create<vector::CreateMaskOp>(
         sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-
-    // Mask the xfer_read Op
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }
 
-  // 4. Generate TransferWriteOp.
-  if (!inputVectorSizes.empty() &&
-      ShapedType::isDynamicShape(resultType.getShape())) {
-    LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
-    return failure();
+  // 3.a. TransferReadOp
+  SmallVector<Value> readIndices(
+      vecType.getRank(),
+      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+  Operation *read = rewriter.create<vector::TransferReadOp>(
+      sliceOp.getLoc(), vecType, source, readIndices, padValue,
+      ArrayRef<bool>{readInBounds});
+
+  // Mask the xfer_read Op
+  if (!inputVectorSizes.empty()) {
+    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }
 
+  // 3.b. TransferWriteOp
   auto writeIndices = getValueOrCreateConstantIndexOp(
       rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
 
-  // 5. Finalize
   Operation *write = rewriter.create<vector::TransferWriteOp>(
       sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
      ArrayRef<bool>{writeInBounds});
+
+  // Mask the xfer_write Op
+  if (!inputVectorSizes.empty()) {
+    write = mlir::vector::maskOperation(rewriter, write, maskOp);
+  }
+
+  // 4. Finalize
   newResults.push_back(write->getResult(0));
 
   return success();

diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index be0180fcf1763..8fbc74ec345c6 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -280,26 +280,3 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
-// -----
-
-// One of the _destination_ dimensions is dynamic (but _source_ dimensions are static).
-
-func.func private @insert_slice_dynamic_dest_dim(%source: tensor<?x?x?x?xi32>, %size: index) -> tensor<?x3xi32> {
-  %c2 = arith.constant 2 : index
-  %init = tensor.empty(%size) : tensor<?x3xi32>
-
-  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x?x?x?xi32> to tensor<5x1xi32>
-  // expected-error @+1 {{Attempted to vectorize, but failed}}
-  %res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<?x3xi32>
-
-  return %res : tensor<?x3xi32>
-}
-
- module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
-    transform.yield
-  }
- }

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a660144ab87fb..ea7a03b08d8d6 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1130,14 +1130,14 @@ func.func private @insert_slice_static_sizes(%source: tensor<?x?x?x?xi32>) -> te
 // CHECK: %[[C_2:.*]] = arith.constant 2 : index
 // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<5x3xi32>
 // CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SEC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x?x?x?xi32> to tensor<5x1xi32>
-// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C_5:.*]] = arith.constant 5 : index
-// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C_5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
 // CHECK: %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C0]], %[[C0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
 // CHECK: %[[C_0:.*]] = arith.constant 0 : index
-// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
 // CHECK: return %[[RES]] : tensor<5x3xi32>
 
 module attributes {transform.with_named_sequence} {
@@ -1170,11 +1170,11 @@ func.func private @insert_slice_dynamic_src_dim(%source: tensor<?x?x?x?xi32>, %s
 // CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x?x?x?xi32> to tensor<?x1xi32>
 // CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
 // CHECK: %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
 // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
 // CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
 // CHECK: return %[[RES]] : tensor<5x3xi32>
 
 module attributes {transform.with_named_sequence} {
@@ -1184,3 +1184,78 @@ func.func private @insert_slice_dynamic_src_dim(%source: tensor<?x?x?x?xi32>, %s
     transform.yield
   }
 }
+
+// -----
+
+// One of the _destination_ dimensions is dynamic (but _source_ dimensions are static).
+
+func.func private @insert_slice_dynamic_dest_dim(%source: tensor<?x?x?x?xi32>, %size: index) -> tensor<?x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty(%size) : tensor<?x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x?x?x?xi32> to tensor<5x1xi32>
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<?x3xi32>
+
+  return %res : tensor<?x3xi32>
+}
+
+// CHECK-LABEL: func.func private @insert_slice_dynamic_dest_dim(
+// CHECK-SAME: %[[SRC:.*]]: tensor<?x?x?x?xi32>,
+// CHECK-SAME: %[[size:.*]]: index) -> tensor<?x3xi32> {
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[size]]) : tensor<?x3xi32>
+// CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x?x?x?xi32> to tensor<5x1xi32>
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[C_5:.*]] = arith.constant 5 : index
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
+// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<?x3xi32> } : vector<8x1xi1> -> tensor<?x3xi32>
+// CHECK: return %[[WRITE]] : tensor<?x3xi32>
+
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  }
+ }
+
+// -----
+
+// At least one _source_ and one _destination_ dimension are dynamic.
+
+func.func private @insert_slice_dynamic_source_and_dest_dim(%source: tensor<?x?x?x?xi32>, %size: index) -> tensor<?x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty(%size) : tensor<?x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<?x?x?x?xi32> to tensor<?x1xi32>
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [%size, 1] [1, 1] : tensor<?x1xi32> into tensor<?x3xi32>
+
+  return %res : tensor<?x3xi32>
+}
+
+// CHECK-LABEL: func.func private @insert_slice_dynamic_source_and_dest_dim(
+// CHECK-SAME: %[[SRC:.*]]: tensor<?x?x?x?xi32>,
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?x3xi32> {
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[SIZE]]) : tensor<?x3xi32>
+// CHECK: %[[SRC_SIZE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x?x?x?xi32> to tensor<?x1xi32>
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C1]] : vector<8x1xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SIZE]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
+// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]]{{\[}}%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<?x3xi32> } : vector<8x1xi1> -> tensor<?x3xi32>
+// CHECK: return %[[WRITE]] : tensor<?x3xi32>
+
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  }
+ }

From 814379d5dd09951b157d47fe43c476ac76441666 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski@arm.com>
Date: Wed, 5 Feb 2025 17:20:19 +0000
Subject: [PATCH 2/3] fixup! [mlir][linalg] Add support for masked
 vectorization of `tensor.insert_slice` (2/N)

Address comment from Hanhan
---
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 1b50c24fbfcfc..f2c23c49a78e8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2743,7 +2743,6 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
         sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
   }
 
-  // 3.a. TransferReadOp
   SmallVector<Value> readIndices(
       vecType.getRank(),
       rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
@@ -2751,12 +2750,10 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
       sliceOp.getLoc(), vecType, source, readIndices, padValue,
       ArrayRef<bool>{readInBounds});
 
-  // Mask the xfer_read Op
-  if (!inputVectorSizes.empty()) {
+  if (maskOp) {
     read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }
 
-  // 3.b. TransferWriteOp
   auto writeIndices = getValueOrCreateConstantIndexOp(
       rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
@@ -2764,8 +2761,7 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   Operation *write = rewriter.create<vector::TransferWriteOp>(
       sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
       ArrayRef<bool>{writeInBounds});
 
-  // Mask the xfer_write Op
-  if (!inputVectorSizes.empty()) {
+  if (maskOp) {
     write = mlir::vector::maskOperation(rewriter, write, maskOp);
   }

From 48d16575263edf92ccf7533a37fc313538f6d231 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski@arm.com>
Date: Thu, 6 Feb 2025 18:13:31 +0000
Subject: [PATCH 3/3] fixup! fixup! [mlir][linalg] Add support for masked
 vectorization of `tensor.insert_slice` (2/N)

Fix capitalisation
---
 mlir/test/Dialect/Linalg/vectorization.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index ea7a03b08d8d6..6d39262945de5 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1201,9 +1201,9 @@ func.func private @insert_slice_dynamic_dest_dim(%source: tensor<?x?x?x?xi32>, %
 // CHECK-LABEL: func.func private @insert_slice_dynamic_dest_dim(
 // CHECK-SAME: %[[SRC:.*]]: tensor<?x?x?x?xi32>,
-// CHECK-SAME: %[[size:.*]]: index) -> tensor<?x3xi32> {
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?x3xi32> {
 // CHECK: %[[C_2:.*]] = arith.constant 2 : index
-// CHECK: %[[INIT:.*]] = tensor.empty(%[[size]]) : tensor<?x3xi32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[SIZE]]) : tensor<?x3xi32>
 // CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x?x?x?xi32> to tensor<5x1xi32>
 // CHECK: %[[PAD:.*]] = arith.constant 0 : i32
 // CHECK: %[[C_5:.*]] = arith.constant 5 : index
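For reference, a minimal sketch of the IR this series produces for the static-sizes example from the first commit message, assembled from the test expectations above. The value names (`%slice`, `%dest`, `%pad`, `%c2`) follow the commit-message example and are illustrative, not verbatim vectorizer output:

```mlir
// AFTER VECTORIZATION (user-provided vector sizes [8, 1], masked)
%c0 = arith.constant 0 : index
%c5 = arith.constant 5 : index
%c1 = arith.constant 1 : index
// The mask is built from the reified source sizes (5 x 1), not the vector shape.
%mask = vector.create_mask %c5, %c1 : vector<8x1xi1>
%read = vector.mask %mask { vector.transfer_read %slice[%c0, %c0], %pad : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
%res = vector.mask %mask { vector.transfer_write %read, %dest[%c0, %c2] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
```

The same mask guards both transfers because both operate on the same `vector<8x1xi32>` value, so the set of valid lanes is identical on the read and write sides; this is what allows the first patch to delete the earlier "TODO: Masking of xfer_write when vectorising" bail-out for dynamically shaped destinations.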