[mlir][linalg] Add support for masked vectorization of tensor.insert_slice (2/N)
#123031
@@ -2716,56 +2716,56 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   }
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());

-  // 3. Generate TransferReadOp.
-  SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
+  // 3. Generate TransferReadOp + TransferWriteOp
+  ReifiedRankedShapedTypeDims reifiedSrcSizes;
+  Value maskOp;

-  // If vector sizes are user provided, make sure to mask xfer_read.
+  // If vector sizes are user provided, make sure to mask. First, generate the
+  // mask.
Contributor:
Couldn't user-provided vector sizes lead to an unmasked scenario? We have a method that checks if a mask is needed here (can't remember the name right now). Couldn't we use it for this case?

Contributor (Author):
Yup, that's …
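For context, a check like the one mentioned above boils down to comparing the user-provided vector sizes against the static source shape. Below is a minimal sketch of that idea, assuming MLIR's `ShapedType` and `llvm::equal` utilities; the helper name `isMaskNeeded` is hypothetical and not necessarily the method being referred to:

    // Hypothetical helper (illustration only): with user-provided vector
    // sizes, masking is redundant when the source shape is fully static and
    // matches those sizes exactly.
    static bool isMaskNeeded(ArrayRef<int64_t> inputVectorSizes,
                             ArrayRef<int64_t> sourceShape) {
      // A dynamic dim may turn out to be smaller than the vector size at
      // runtime, so it must be masked.
      if (ShapedType::isDynamicShape(sourceShape))
        return true;
      // Fully static shape: mask only if the vector is wider than the data.
      return !llvm::equal(inputVectorSizes, sourceShape);
    }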
   if (!inputVectorSizes.empty()) {
     auto *srcDefOp = source.getDefiningOp();
     if (!srcDefOp) {
       LDBG("Unable to get the defining Op of " << sliceOp);
       return failure();
     }

-    ReifiedRankedShapedTypeDims reifiedSrcSizes;
     LogicalResult status =
         cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
             rewriter, reifiedSrcSizes);
     if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << sliceOp);
+      LDBG("Unable to reify result shapes of " << srcDefOp);
       return failure();
     }

     // Create the mask
-    SmallVector<int64_t> readMaskShape(
-        sliceOp.getSource().getType().getShape());
     auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    Value maskOp = rewriter.create<vector::CreateMaskOp>(
+    maskOp = rewriter.create<vector::CreateMaskOp>(
         sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-
-    // Mask the xfer_read Op
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }

-  // 4. Generate TransferWriteOp.
-  if (!inputVectorSizes.empty() &&
-      ShapedType::isDynamicShape(resultType.getShape())) {
-    LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
-    return failure();
+  SmallVector<Value> readIndices(
+      vecType.getRank(),
+      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+  Operation *read = rewriter.create<vector::TransferReadOp>(
+      sliceOp.getLoc(), vecType, source, readIndices, padValue,
+      ArrayRef<bool>{readInBounds});
+
+  if (maskOp) {
+    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }

   auto writeIndices = getValueOrCreateConstantIndexOp(
       rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());

-  // 5. Finalize
   Operation *write = rewriter.create<vector::TransferWriteOp>(
       sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
       ArrayRef<bool>{writeInBounds});

+  if (maskOp) {
+    write = mlir::vector::maskOperation(rewriter, write, maskOp);
+  }
+
   // 4. Finalize
   newResults.push_back(write->getResult(0));

   return success();
@@ -1130,14 +1130,14 @@ func.func private @insert_slice_static_sizes(%source: tensor<?x3x?x1xi32>) -> te
 // CHECK: %[[C_2:.*]] = arith.constant 2 : index
 // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<5x3xi32>
 // CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SEC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
-// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C_5:.*]] = arith.constant 5 : index
-// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C_5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
 // CHECK: %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C0]], %[[C0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
 // CHECK: %[[C_0:.*]] = arith.constant 0 : index
-// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
 // CHECK: return %[[RES]] : tensor<5x3xi32>

 module attributes {transform.with_named_sequence} {
@@ -1170,11 +1170,11 @@ func.func private @insert_slice_dynamic_src_dim(%source: tensor<?x3x?x1xi32>, %s
 // CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
 // CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
 // CHECK: %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
 // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
 // CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
 // CHECK: return %[[RES]] : tensor<5x3xi32>

 module attributes {transform.with_named_sequence} {
@@ -1184,3 +1184,78 @@ func.func private @insert_slice_dynamic_src_dim(%source: tensor<?x3x?x1xi32>, %s
     transform.yield
   }
 }
+
+// -----
+
+// One of the _destination_ dimensions is dynamic (but _source_ dimensions are static).
+
+func.func private @insert_slice_dynamic_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty(%size) : tensor<?x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<?x3xi32>
+
+  return %res : tensor<?x3xi32>
+}
+
+// CHECK-LABEL: func.func private @insert_slice_dynamic_dest_dim(
+// CHECK-SAME:  %[[SRC:.*]]: tensor<?x3x?x1xi32>,
+// CHECK-SAME:  %[[size:.*]]: index) -> tensor<?x3xi32> {
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[size]]) : tensor<?x3xi32>
+// CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[C_5:.*]] = arith.constant 5 : index
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
+// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<?x3xi32> } : vector<8x1xi1> -> tensor<?x3xi32>
+// CHECK: return %[[WRITE]] : tensor<?x3xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// At least one _source_ and one _destination_ dimensions are dynamic.
+
+func.func private @insert_slice_dynamic_source_and_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty(%size) : tensor<?x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [%size, 1] [1, 1] : tensor<?x1xi32> into tensor<?x3xi32>
+
+  return %res : tensor<?x3xi32>
+}
+
+// CHECK-LABEL: func.func private @insert_slice_dynamic_source_and_dest_dim(
+// CHECK-SAME:  %[[SRC:.*]]: tensor<?x3x?x1xi32>,
+// CHECK-SAME:  %[[SIZE:.*]]: index) -> tensor<?x3xi32> {
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[SIZE]]) : tensor<?x3xi32>
+// CHECK: %[[SRC_SIZE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C1]] : vector<8x1xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SIZE]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
+// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]]{{\[}}%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<?x3xi32> } : vector<8x1xi1> -> tensor<?x3xi32>
+// CHECK: return %[[WRITE]] : tensor<?x3xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  }
+}
Contributor:
Move the declaration to where it is initialized, i.e., l.2742?
Contributor (Author):
Note that l.2742 sits within an `if` block and the generated mask is also used outside it, e.g. at l.2756. This is roughly the structure:
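(a condensed sketch following the diff above; `...` marks elided arguments)

    Value maskOp;                                         // declared up front ...
    if (!inputVectorSizes.empty()) {
      // ... reify the source sizes, build the mask type ...
      maskOp = rewriter.create<vector::CreateMaskOp>(...);  // ... initialized here (l.2742)
    }

    Operation *read = rewriter.create<vector::TransferReadOp>(...);
    if (maskOp)
      read = mlir::vector::maskOperation(rewriter, read, maskOp);  // ... used here (l.2756)

    Operation *write = rewriter.create<vector::TransferWriteOp>(...);
    if (maskOp)
      write = mlir::vector::maskOperation(rewriter, write, maskOp);  // ... and used again here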
I have ideas for how to improve this, but no spare cycles 😢 (there are `createWriteOrMaskedWrite` and `createReadOrMaskedRead` that we should re-use here, but that won't work as-is). If that's OK, I will add this to my TODO list?
Contributor:
Sorry that I missed it. I see, thanks!