diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a43fa86166e83..f2c23c49a78e8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2716,15 +2716,12 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   }
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());

-  // 3. Generate TransferReadOp.
-  SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
+  // 3. Generate TransferReadOp + TransferWriteOp
+  ReifiedRankedShapedTypeDims reifiedSrcSizes;
+  Value maskOp;

-  // If vector sizes are user provided, make sure to mask xfer_read.
+  // If vector sizes are user provided, make sure to mask. First, generate the
+  // mask.
   if (!inputVectorSizes.empty()) {
     auto *srcDefOp = source.getDefiningOp();
     if (!srcDefOp) {
@@ -2732,40 +2729,43 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
       return failure();
     }

-    ReifiedRankedShapedTypeDims reifiedSrcSizes;
     LogicalResult status =
         cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
            rewriter, reifiedSrcSizes);
     if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << sliceOp);
+      LDBG("Unable to reify result shapes of " << srcDefOp);
       return failure();
     }

     // Create the mask
-    SmallVector<int64_t> readMaskShape(
-        sliceOp.getSource().getType().getShape());
     auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    Value maskOp = rewriter.create<vector::CreateMaskOp>(
+    maskOp = rewriter.create<vector::CreateMaskOp>(
        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-
-    // Mask the xfer_read Op
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }

-  // 4. Generate TransferWriteOp.
-  if (!inputVectorSizes.empty() &&
-      ShapedType::isDynamicShape(resultType.getShape())) {
-    LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
-    return failure();
+  SmallVector<Value> readIndices(
+      vecType.getRank(),
+      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+  Operation *read = rewriter.create<vector::TransferReadOp>(
+      sliceOp.getLoc(), vecType, source, readIndices, padValue,
+      ArrayRef<bool>{readInBounds});
+
+  if (maskOp) {
+    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }

   auto writeIndices = getValueOrCreateConstantIndexOp(
       rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());

-  // 5. Finalize
   Operation *write = rewriter.create<vector::TransferWriteOp>(
       sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
       ArrayRef<bool>{writeInBounds});
+
+  if (maskOp) {
+    write = mlir::vector::maskOperation(rewriter, write, maskOp);
+  }
+
+  // 4. Finalize
   newResults.push_back(write->getResult(0));

   return success();
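In IR terms, the change above teaches vectorizeAsInsertSliceOp to guard the generated vector.transfer_write with the same vector.create_mask as the vector.transfer_read, instead of bailing out whenever the destination shape is dynamic. A minimal hand-written sketch of the resulting IR, assuming user-provided vector sizes [8, 1] and illustrative SSA names (the exact output is pinned down by the updated tests below):

  // Vectorising: tensor.insert_slice of a static tensor<5x1xi32> slice
  // into a dynamically shaped tensor<?x3xi32> destination.
  %mask = vector.create_mask %c5, %c1 : vector<8x1xi1>
  %read = vector.mask %mask {
      vector.transfer_read %src_slice[%c0, %c0], %pad : tensor<5x1xi32>, vector<8x1xi32>
    } : vector<8x1xi1> -> vector<8x1xi32>
  // Previously this write made vectorisation fail ("TODO: Masking of
  // xfer_write ..."); it is now masked with the same mask as the read.
  %res = vector.mask %mask {
      vector.transfer_write %read, %init[%c0, %c2] : vector<8x1xi32>, tensor<?x3xi32>
    } : vector<8x1xi1> -> tensor<?x3xi32>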
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index be0180fcf1763..8fbc74ec345c6 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -280,26 +280,3 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
-// -----
-
-// One of the _destination_ dimensions is dynamic (but _source_ dimensions are static).
-
-func.func private @insert_slice_dynamic_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
-  %c2 = arith.constant 2 : index
-  %init = tensor.empty(%size) : tensor<?x3xi32>
-
-  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
-  // expected-error @+1 {{Attempted to vectorize, but failed}}
-  %res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<?x3xi32>
-
-  return %res : tensor<?x3xi32>
-}
-
- module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
-    transform.yield
-  }
- }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a660144ab87fb..6d39262945de5 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1130,14 +1130,14 @@ func.func private @insert_slice_static_sizes(%source: tensor<?x3x?x1xi32>) -> tensor<5x3xi32> {
 // CHECK: %[[C_2:.*]] = arith.constant 2 : index
 // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<5x3xi32>
 // CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SEC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
-// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C_5:.*]] = arith.constant 5 : index
-// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C_5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
 // CHECK: %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C0]], %[[C0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
 // CHECK: %[[C_0:.*]] = arith.constant 0 : index
-// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
 // CHECK: return %[[RES]] : tensor<5x3xi32>

 module attributes {transform.with_named_sequence} {
@@ -1170,11 +1170,11 @@ func.func private @insert_slice_dynamic_src_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<5x3xi32> {
 // CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
 // CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
 // CHECK: %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
 // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
 // CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
 // CHECK: return %[[RES]] : tensor<5x3xi32>

 module attributes {transform.with_named_sequence} {
@@ -1184,3 +1184,78 @@ func.func private @insert_slice_dynamic_src_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<5x3xi32> {
     transform.yield
   }
 }
+
+// -----
+
+// One of the _destination_ dimensions is dynamic (but _source_ dimensions are static).
+
+func.func private @insert_slice_dynamic_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty(%size) : tensor<?x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<?x3xi32>
+
+  return %res : tensor<?x3xi32>
+}
+
+// CHECK-LABEL: func.func private @insert_slice_dynamic_dest_dim(
+// CHECK-SAME: %[[SRC:.*]]: tensor<?x3x?x1xi32>,
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?x3xi32> {
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[SIZE]]) : tensor<?x3xi32>
+// CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[C_5:.*]] = arith.constant 5 : index
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
+// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<?x3xi32> } : vector<8x1xi1> -> tensor<?x3xi32>
+// CHECK: return %[[WRITE]] : tensor<?x3xi32>
+
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  }
+ }
+
+// -----
+
+// At least one _source_ and one _destination_ dimension are dynamic.
+
+func.func private @insert_slice_dynamic_source_and_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty(%size) : tensor<?x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [%size, 1] [1, 1] : tensor<?x1xi32> into tensor<?x3xi32>
+
+  return %res : tensor<?x3xi32>
+}
+
+// CHECK-LABEL: func.func private @insert_slice_dynamic_source_and_dest_dim(
+// CHECK-SAME: %[[SRC:.*]]: tensor<?x3x?x1xi32>,
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?x3xi32> {
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[SIZE]]) : tensor<?x3xi32>
+// CHECK: %[[SRC_SIZE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C1]] : vector<8x1xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SIZE]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
+// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]]{{\[}}%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<?x3xi32> } : vector<8x1xi1> -> tensor<?x3xi32>
+// CHECK: return %[[WRITE]] : tensor<?x3xi32>
+
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  }
+ }
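Note on exercising the tests: the RUN lines sit at the top of each test file and are outside this diff, so the forms below are an assumption based on the standard transform-interpreter setup used by these files rather than something the patch shows:

  // RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s                  (vectorization.mlir)
  // RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics             (vectorization-unsupported.mlir)

The distinction matters for the hunk that drops @insert_slice_dynamic_dest_dim from vectorization-unsupported.mlir: its expected-error annotation is deleted along with the test, and the same case is re-added to vectorization.mlir as a positive, FileCheck-verified test.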