188 changes: 187 additions & 1 deletion mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"

@@ -428,6 +429,190 @@ struct BubbleUpExpandShapeThroughExtractSlice
}
};

/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
/// `tensor.collapse_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible, the slice must be representable as a
/// contiguous slice within each reassociation group of the src.
///
/// If the extracted size and offset are static, this is possible when the
/// following conditions are met:
/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
/// be the shape of a desired slice. A slice of shape S can be extracted as a
/// contiguous block of memory if and only if there exists an index k in {0, 1,
/// ..., n} such that:
/// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
/// 1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly
/// one dimension),
/// S_i = A_i for all i > k (that is, all trailing dimensions are preserved
/// in full).
/// In other words, the slice shape S must be of the form:
/// [ 1, 1, ..., 1, S_k, A_(k+1), A_(k+2), ..., A_n ]
/// For example, given T of shape [4, 8, 16], the slice shape [1, 3, 16] is
/// contiguous (k = 1), whereas [2, 3, 16] is not.
///
/// If the extracted size and/or offset are dynamic, this is possible only if
/// there is a single dimension in the reassociation group whose size is not
/// equal to 1.
/// In other words, the tensor shape must be of the form:
/// [ 1, 1, ..., 1, A, 1, ..., 1 ]
/// Note: it might be possible to enable this pattern for more dynamic
/// size/offset cases by analyzing the possible values the size/offset could
/// take.
///
/// Example:
/// The transformation is possible because each reassociation group can be
/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
/// [20->10]).
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
/// tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
/// tensor<128x7x20xf32> to tensor<32x?x10xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
/// [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
Review comment (Contributor):
Whew! This is actually correct. Took me a while to work it out, but this works because, for the reassociation, the other dimensions are all 1.
/// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
/// ```
struct BubbleUpCollapseShapeThroughExtractSlice
Review comment (Contributor):
I'm coming from #153675, which performs the refactoring.
Similar to https://github.com/llvm/llvm-project/pull/126898/files#r2286013250, should it be BubbleUpExtractSliceThroughCollapseShape?

Reply (Author):
Yes, thanks for noticing.
: public OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
auto collapseShapeOp =
sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
if (!collapseShapeOp)
return rewriter.notifyMatchFailure(
sliceOp,
"tensor.extract_slice source not produced by tensor.collapse_shape");

if (!sliceOp.hasUnitStride()) {
return rewriter.notifyMatchFailure(
sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
"be supported in this transformation.");
}

// The tensor.extract_slice before applying the pattern works on the result
// of the tensor.collapse_shape, so variables (i.e. inputs for
// ExtractSliceOp) referring to the state before applying the pattern are
// named with the prefix "collapsed", and ones referring to the state after
// applying the pattern are named with the prefix "expanded".
SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();

if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
collapsedSizes.size())
return rewriter.notifyMatchFailure(sliceOp,
Review comment (Contributor):
Nit: Please add { } around multi-line statements.

Reply (Author):
Added in all the relevant places in the pattern.
"unimplemented: rank reducing slice");

ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
SmallVector<ReassociationIndices, 4> reassociationIndices =
collapseShapeOp.getReassociationIndices();

// Compute new offsets, sizes, and strides for tensor.extract_slice.
// The new tensor.extract_slice will work on a tensor that has a rank
// equal to the rank of the src of the collapse_shape. In each iteration of
// the loop, the offsets and sizes will be computed per reassociation group.
SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
rewriter.getIndexAttr(1));

for (auto [groupIdx, reassocIndices] :
Review comment (Contributor):
Nit: You could do
for (auto [collapsedSize, collapsedOffset, reassocIndices] : llvm::zip_equal(collapsedSizes, collapsedOffsets, collapseShapeOp.getReassociationIndices()))

Reply (Author):
Nice! Updated.
enumerate(collapseShapeOp.getReassociationIndices())) {
OpFoldResult collapsedSize = collapsedSizes[groupIdx];
OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
// Case #1 - size and/or offset are dynamic.
// In this case, the slice can be represented as a contiguous slice only
// if there is a single dimension in the reassociation group that has a
// size not equal to 1.
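// For example (illustrative values), for a reassociation group collapsing
// [1, 7, 1] -> 7 and a dynamic slice %off/%size along the collapsed
// dimension, the expanded offsets become [0, %off, 0] and the expanded
// sizes [1, %size, 1].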
if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
int nonUnitSizeCount = 0;
for (int64_t expandedShapeIdx : reassocIndices) {
if (srcShape[expandedShapeIdx] != 1) {
nonUnitSizeCount++;
expandedSizes.push_back(collapsedSize);
expandedOffsets.push_back(collapsedOffset);
continue;
}

expandedSizes.push_back(rewriter.getIndexAttr(1));
expandedOffsets.push_back(rewriter.getIndexAttr(0));
}

if (nonUnitSizeCount != 1) {
return rewriter.notifyMatchFailure(
sliceOp,
"unsupported: slice cannot be verified to be contiguous");
}
continue;
}

// Case #2 - size and offset are static.
// Verify that the slice can be represented as a contiguous slice of the
// src of the collapse_shape.
// This check must be done starting from the innermost dimension, so the
// traversal is done in reverse order of the reassociation group.
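// Worked example (illustrative): collapsing [8, 16] -> 128 with a static
// slice of size 32 at offset 64:
//   dim 16 (innermost): offsetInDim = 64 % 16 = 0, size = min(32, 16) = 16;
//   the remaining size becomes 32 / 16 = 2 and the remaining offset
//   64 / 16 = 4.
//   dim 8: offsetInDim = 4, size = min(2, 8) = 2.
// After reversal: expanded sizes [2, 16], expanded offsets [4, 0].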
int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
int64_t collapsedOffsetValue =
getConstantIntValue(collapsedOffset).value();

SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;

for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
int64_t expandedShapeSize = srcShape[expandedShapeIdx];

// The remaining slice covers this dimension in full, so we need to make
// sure that both the slice size and the offset are divisible by this
// dimension's size (i.e., the size here can be set to the full shape size
// and the offset within this dimension to 0).
if (collapsedSizeValue >= expandedShapeSize &&
(collapsedSizeValue % expandedShapeSize != 0 ||
collapsedOffsetValue % expandedShapeSize != 0)) {
return rewriter.notifyMatchFailure(
sliceOp, "unsupported: cannot be extracted as a contiguous slice "
"of the src of the collapse_shape");
}

int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;

// This is the dimension that slicing will occur along, so we need to make
// sure that the slice size + offset will not exceed the shape size.
if (collapsedSizeValue < expandedShapeSize &&
(collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
return rewriter.notifyMatchFailure(
sliceOp, "unsupported: slice cannot be extracted as a contiguous "
"slice of the src of the collapse_shape");
}

groupExpandedSizes.push_back(rewriter.getIndexAttr(
std::min(collapsedSizeValue, expandedShapeSize)));
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));

// Fold this dimension's extent out of the remaining collapsed size and
// offset before moving on to the next (outer) dimension.
collapsedSizeValue /= expandedShapeSize;
collapsedSizeValue = std::max<int64_t>(collapsedSizeValue, 1);
collapsedOffsetValue /= expandedShapeSize;
}

expandedSizes.append(groupExpandedSizes.rbegin(),
groupExpandedSizes.rend());
expandedOffsets.append(groupExpandedOffsets.rbegin(),
groupExpandedOffsets.rend());
}

Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
expandedSizes, expandedStrides);
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
sliceOp, sliceOp.getResultType(), newSliceOp,
collapseShapeOp.getReassociationIndices());

return success();
}
};

} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -448,5 +633,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(

void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
RewritePatternSet &patterns) {
patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
patterns.add<BubbleUpExpandShapeThroughExtractSlice,
BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
}
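
For context (not part of this diff), here is a minimal sketch of how these patterns could be driven from a standalone pass, assuming the upstream greedy rewrite driver (applyPatternsGreedily; older MLIR releases use applyPatternsAndFoldGreedily). The pass name and structure are hypothetical:
```cpp
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// Hypothetical test pass that exercises the bubble-up patterns.
struct BubbleUpExtractSliceTestPass
    : public mlir::PassWrapper<BubbleUpExtractSliceTestPass,
                               mlir::OperationPass<>> {
  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Adds BubbleUpExpandShapeThroughExtractSlice and the new
    // BubbleUpCollapseShapeThroughExtractSlice pattern.
    mlir::tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
    if (failed(mlir::applyPatternsGreedily(getOperation(),
                                           std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```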
50 changes: 50 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -438,3 +438,53 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape(
// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -> (tensor<8x1800x32xf32>) {
// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]]
// CHECK: %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
%collapse = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
%empty = tensor.empty() : tensor<8x1800x32xf32>
%exp = linalg.exp ins(%collapse : tensor<8x1800x32xf32>) outs(%empty : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32>
return %exp : tensor<8x1800x32xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
(!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
transform.yield
}
}


// -----

// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(
// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
// CHECK: %[[ABS:.*]] = linalg.abs ins(%[[EXTRACT]]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]]
// CHECK: %[[EXP:.*]] = linalg.exp ins(%[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
%empty1 = tensor.empty() : tensor<1x8x1800x32xf32>
%abs = linalg.abs ins(%0 : tensor<1x8x1800x32xf32>) outs(%empty1 : tensor<1x8x1800x32xf32>) -> tensor<1x8x1800x32xf32>
%collapse = tensor.collapse_shape %abs [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
%empty2 = tensor.empty() : tensor<8x1800x32xf32>
%exp = linalg.exp ins(%collapse : tensor<8x1800x32xf32>) outs(%empty2 : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32>
return %exp : tensor<8x1800x32xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
(!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
transform.yield
}
}