[MLIR] Bubble up tensor.extract_slice through tensor.collapse_shape #131982
@@ -12,8 +12,10 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include <algorithm>

using namespace mlir;
using namespace mlir::tensor;

@@ -428,6 +430,190 @@ struct BubbleUpExpandShapeThroughExtractSlice
  }
};

/// Converts `tensor.collapse_shape(tensor.extract_slice)` to
/// `tensor.extract_slice(tensor.collapse_shape)`.
///
/// For this transformation to be possible, the slice must be representable as a
/// contiguous slice within each reassociation group of the src.
///
/// In the case where the extracted size and offset are static, this is
/// possible if the following conditions are met:
/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
/// be the shape of a desired slice. A slice of shape S can be extracted as a
/// contiguous block of memory if and only if there exists an index k in {0, 1,
/// ..., n} such that:
///   S_i = 1 for all i < k (that is, all leading dimensions are singleton),
///   1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly
///   one dimension),
///   S_i = A_i for all i > k (that is, all trailing dimensions are preserved
///   in full).
/// In other words, the slice shape S must be of the form:
///   [ 1, 1, ..., 1, S_k, A_{k+1}, A_{k+2}, ..., A_n ]
///
/// In the case where the extracted size and/or offset are dynamic, this is
/// possible only if there is a single dimension in the reassociation group
/// that has a size not equal to 1.
/// In other words, the tensor shape must be of the form:
///   [ 1, 1, ..., 1, A, 1, ..., 1 ]
/// Note - it might be possible to enable this pattern for more cases when the
/// size/offset are dynamic, by performing an analysis of the possible values
/// that could be given to the size/offset.
///
/// Example:
/// The transformation is possible because each reassociation group can be
/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
/// [20->10]).
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
Contributor
Whew! This is actually correct. It took me a while to work out, but this works because, for the reassociation, the other dimensions are all 1.
///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
/// ```
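To make the static-case arithmetic above concrete, here is a small standalone sketch, not part of the patch; the helper name `decomposeStaticGroupSlice`, the `GroupSlice` struct, and the test values are illustrative only. It applies the same per-group modular arithmetic as the pattern: peel the expanded dimensions of one reassociation group from innermost to outermost, check the contiguity conditions, and emit one size/offset pair per expanded dimension. The asserts reproduce the [8x16->2x16] and [20->10] groups from the example.

```cpp
// Illustrative sketch, not part of the patch: per-group arithmetic for the
// static case, mirroring the checks and updates done in the pattern below.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

struct GroupSlice {
  std::vector<int64_t> sizes;   // one entry per expanded dim, outermost first
  std::vector<int64_t> offsets; // one entry per expanded dim, outermost first
};

// `dims` holds the expanded sizes of one reassociation group, outermost first.
// `size`/`offset` are the static slice size/offset on the collapsed dimension.
// Returns std::nullopt when the slice cannot be expressed contiguously.
std::optional<GroupSlice>
decomposeStaticGroupSlice(const std::vector<int64_t> &dims, int64_t size,
                          int64_t offset) {
  GroupSlice result;
  result.sizes.resize(dims.size());
  result.offsets.resize(dims.size());
  // Walk innermost -> outermost, peeling one expanded dimension at a time
  // (the llvm::reverse loop in the pattern).
  for (int64_t i = static_cast<int64_t>(dims.size()) - 1; i >= 0; --i) {
    int64_t dim = dims[i];
    // A dimension covered in full: the remaining size and offset must both be
    // multiples of it (the "S_i = A_i for i > k" part of the condition).
    if (size >= dim && (size % dim != 0 || offset % dim != 0))
      return std::nullopt;
    int64_t offsetInDim = offset % dim;
    // The single partially-sliced dimension ("S_k"): it must not run past the
    // end of the dimension (the pattern requires size + offset < dim here).
    if (size < dim && size + offsetInDim >= dim)
      return std::nullopt;
    result.sizes[i] = std::min(size, dim);
    result.offsets[i] = offsetInDim;
    // Strip this dimension off the remaining size/offset; once S_k has been
    // placed, all outer dimensions degenerate to size 1.
    size = std::max<int64_t>(size / dim, 1);
    offset /= dim;
  }
  return result;
}

int main() {
  // Group [8, 16] from the example: size 32 at offset 0 of the collapsed 128
  // becomes sizes [2, 16], offsets [0, 0].
  auto g = decomposeStaticGroupSlice({8, 16}, /*size=*/32, /*offset=*/0);
  assert(g && g->sizes == std::vector<int64_t>({2, 16}) &&
         g->offsets == std::vector<int64_t>({0, 0}));
  // Group [20]: size 10 at offset 0 stays [10] at [0].
  auto h = decomposeStaticGroupSlice({20}, /*size=*/10, /*offset=*/0);
  assert(h && h->sizes == std::vector<int64_t>({10}));
  // 20 consecutive elements of an 8x16 group do not form a rectangular slice.
  assert(!decomposeStaticGroupSlice({8, 16}, /*size=*/20, /*offset=*/0));
  return 0;
}
```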
struct BubbleUpCollapseShapeThroughExtractSlice
Contributor
I'm coming from #153675, which performs the refactoring. Similar to https://github.com/llvm/llvm-project/pull/126898/files#r2286013250, should it be

Contributor (Author)
Yes, thanks for noticing.
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseShapeOp)
      return rewriter.notifyMatchFailure(
          sliceOp,
          "tensor.extract_slice source not produced by tensor.collapse_shape");

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    // The tensor.extract_slice before applying the pattern works on the result
    // of the tensor.collapse_shape, so variables (i.e. inputs for
    // ExtractSliceOp) referring to the state before applying the pattern are
    // named with the prefix "collapsed", and ones referring to the state after
    // applying the pattern are named with the prefix "expanded".
    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();

    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        collapsedSizes.size())
      return rewriter.notifyMatchFailure(sliceOp,
Contributor
Nit: Please add

Contributor (Author)
Added in all the relevant places in the pattern.
                                         "unimplemented: rank reducing slice");

    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
    SmallVector<ReassociationIndices, 4> reassociationIndices =
        collapseShapeOp.getReassociationIndices();

    // Compute new offsets, sizes, and strides for tensor.extract_slice.
    // The new tensor.extract_slice will work on a tensor that has a rank equal
    // to the rank of the src of the collapse_shape. In each iteration of the
    // loop, the offsets and sizes will be computed per reassociation group.
    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
                                              rewriter.getIndexAttr(1));

    for (auto [groupIdx, reassocIndices] :
         enumerate(collapseShapeOp.getReassociationIndices())) {
      OpFoldResult collapsedSize = collapsedSizes[groupIdx];
      OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
      // Case #1 - size and/or offset are dynamic.
      // In this case, the slice can be represented as a contiguous slice only
      // if there is a single dimension in the reassociation group that has a
      // size not equal to 1.
      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
        int nonUnitSizeCount = 0;
        for (int64_t expandedShapeIdx : reassocIndices) {
          if (srcShape[expandedShapeIdx] != 1) {
            nonUnitSizeCount++;
            expandedSizes.emplace_back(collapsedSize);
            expandedOffsets.emplace_back(collapsedOffset);
            continue;
          }

          expandedSizes.emplace_back(rewriter.getIndexAttr(1));
          expandedOffsets.emplace_back(rewriter.getIndexAttr(0));
        }

        if (nonUnitSizeCount != 1) {
          return rewriter.notifyMatchFailure(
              sliceOp,
              "unsupported: slice cannot be verified to be contiguous");
        }
        continue;
      }

      // Case #2 - size and offset are static.
      // Verify that the slice can be represented as a contiguous slice of the
      // src of the collapse_shape.
      // Checking this must be done in order of the innermost dimensions first,
      // so traversal is done in reverse order of the reassociation group.
      int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
      int64_t collapsedOffsetValue =
          getConstantIntValue(collapsedOffset).value();

      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;

      for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
        int64_t expandedShapeSize = srcShape[expandedShapeIdx];

        // This is a dimension that slicing will occur on, so need to make sure
        // that the slice size can be set to the shape size and the offset to 0.
        if (collapsedSizeValue >= expandedShapeSize &&
            (collapsedSizeValue % expandedShapeSize != 0 ||
             collapsedOffsetValue % expandedShapeSize != 0)) {
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
                       "of the src of the collapse_shape");
        }

        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;

        // This is the dimension that slicing will occur along, so need to make
        // sure that the slice size + offset will not exceed the shape size.
        if (collapsedSizeValue < expandedShapeSize &&
            (collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: slice cannot be extracted as a contiguous "
                       "slice of the src of the collapse_shape");
        }

        groupExpandedSizes.emplace_back(rewriter.getIndexAttr(
            std::min(collapsedSizeValue, expandedShapeSize)));
        groupExpandedOffsets.emplace_back(rewriter.getIndexAttr(offsetInDim));

        // Remove the size and offset of trailing dimensions from the size and
        // offset of the slice.
        collapsedSizeValue /= expandedShapeSize;
        collapsedSizeValue = std::max<int64_t>(collapsedSizeValue, 1);
        collapsedOffsetValue /= expandedShapeSize;
      }

      expandedSizes.append(groupExpandedSizes.rbegin(),
                           groupExpandedSizes.rend());
      expandedOffsets.append(groupExpandedOffsets.rbegin(),
                             groupExpandedOffsets.rend());
    }

    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
        expandedSizes, expandedStrides);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        sliceOp, sliceOp.getResultType(), newSliceOp,
        collapseShapeOp.getReassociationIndices());

    return success();
  }
};

} // namespace
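For the dynamic case handled by the pattern above, the rule reduces to a structural test on the group's expanded shape. Below is a hedged standalone sketch, again not part of the patch; the helper name `forwardDynamicGroupSlice` and the `DynGroupSlice` struct are illustrative placeholders for what the pattern does with `OpFoldResult`s.

```cpp
// Illustrative sketch, not part of the patch: the dynamic-case rule. When a
// group's slice size/offset are only known at runtime, the rewrite is only
// proven contiguous if exactly one expanded dimension of the group is
// non-unit; the dynamic size/offset are forwarded to that dimension, and
// every unit dimension gets size 1 and offset 0.
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// For each expanded dim, nullopt means "forward the dynamic size/offset";
// a value means a constant size/offset.
struct DynGroupSlice {
  std::vector<std::optional<int64_t>> sizes;
  std::vector<std::optional<int64_t>> offsets;
};

std::optional<DynGroupSlice>
forwardDynamicGroupSlice(const std::vector<int64_t> &dims) {
  DynGroupSlice result;
  int nonUnitCount = 0;
  for (int64_t dim : dims) {
    if (dim != 1) {
      ++nonUnitCount;
      result.sizes.push_back(std::nullopt);   // dynamic size goes here
      result.offsets.push_back(std::nullopt); // dynamic offset goes here
    } else {
      result.sizes.push_back(1);
      result.offsets.push_back(0);
    }
  }
  // With more than one non-unit dimension, a runtime size/offset could
  // straddle a dimension boundary, so contiguity cannot be guaranteed.
  return nonUnitCount == 1 ? std::optional<DynGroupSlice>(result)
                           : std::nullopt;
}

int main() {
  assert(forwardDynamicGroupSlice({1, 7}));   // the [1, 7] group: OK
  assert(!forwardDynamicGroupSlice({8, 16})); // two non-unit dims: bail out
  return 0;
}
```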

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(

@@ -448,5 +634,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(

 void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
+               BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
 }
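As a usage note, not shown in this diff: callers typically hand these patterns to MLIR's greedy rewrite driver. A minimal sketch follows, assuming `op` is some enclosing op such as a `func::FuncOp` and that the helper name `bubbleUpExtractSlices` is hypothetical; the driver entry point `applyPatternsAndFoldGreedily` is spelled `applyPatternsGreedily` in newer MLIR releases.

```cpp
// Hedged usage sketch, not part of the patch: populate the bubble-up patterns
// (which now include BubbleUpCollapseShapeThroughExtractSlice) and apply them
// greedily to `op`.
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult bubbleUpExtractSlices(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
  // Returns failure if the greedy driver did not converge.
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```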