Commit 72fbf71
[MLIR] Bubble up tensor.extract_slice through tensor.collapse_shape
Add a pattern that bubbles up tensor.extract_slice through tensor.collapse_shape. The pattern is registered in a pattern population function that is used by the transform op transform.apply_patterns.tensor.bubble_up_extract_slice and by the transform op transform.structured.fuse as a cleanup pattern. When added as a cleanup pattern of the tile-and-fuse utility, this pattern enables tiling and fusing op chains that contain tensor.collapse_shape; without it that would not be possible, since tensor.collapse_shape does not implement the tiling interface. This is an additional pattern to the one added in PR llvm#126898
1 parent d325547 commit 72fbf71
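
As a rough sketch of how the new pattern can be exercised on its own (the payload matching and sequence below are illustrative, not taken from this commit), the population function can also be driven directly through the transform dialect:

```mlir
// Hypothetical driver: applies the bubble-up-extract-slice patterns to every
// func.func in the payload module.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %funcs {
      transform.apply_patterns.tensor.bubble_up_extract_slice
    } : !transform.any_op
    transform.yield
  }
}
```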

File tree

3 files changed: +391 -1 lines changed

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 188 additions & 1 deletion
@@ -12,8 +12,10 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LogicalResult.h"
+#include <algorithm>
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -428,6 +430,190 @@ struct BubbleUpExpandShapeThroughExtractSlice
   }
 };
 
+/// Converts `tensor.collapse_shape(tensor.extract_slice)` to
+/// `tensor.extract_slice(tensor.collapse_shape)`.
+///
+/// For this transformation to be possible, the slice must be representable as
+/// a contiguous slice within each reassociation group of the src.
+///
+/// In case the size and offset extracted are static, this is possible if
+/// the following conditions are met:
+/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
+/// be the shape of a desired slice. A slice of shape S can be extracted as a
+/// contiguous block of memory if and only if there exists an index k in {0, 1,
+/// ..., n} such that:
+///      S_i = 1 for all i < k (that is, all leading dimensions are singleton),
+///      1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly
+///                       one dimension),
+///      S_i = A_i for all i > k (that is, all trailing dimensions are preserved
+///                               in full).
+/// In other words, the slice shape S must be of the form:
+///      [ 1, 1, ..., 1, S_k, A_(k+1), A_(k+2), ..., A_n ]
+///
+/// In case the size and/or offset extracted are dynamic, this is possible
+/// only if there is a single dimension in the reassociation group that has a
+/// size not equal to 1.
+/// In other words, the tensor shape must be of the form:
+///      [ 1, 1, ..., 1, A, 1, ..., 1 ]
+/// Note - it might be possible to enable this pattern for more cases when the
+/// size/offset are dynamic by performing an analysis of the possible values
+/// that could be given to the size/offset.
+///
+/// Example:
+/// The transformation is possible because each reassociation group can be
+/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
+/// [20->10]).
+/// ```
+/// BEFORE:
+/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
+///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
+/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
+///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
+///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
+/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
+///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
+/// ```
+struct BubbleUpCollapseShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "tensor.extract_slice source not produced by tensor.collapse_shape");
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+                   "be supported in this transformation.");
+    }
+
+    // The tensor.extract_slice before applying the pattern works on the result
+    // of the tensor.collapse_shape, so variables (i.e. inputs for
+    // ExtractSliceOp) referring to the state before applying the pattern are
+    // named with the prefix "collapsed", and ones referring to the state after
+    // applying the pattern are named with the prefix "expanded".
+    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        collapsedSizes.size())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+
+    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
+    SmallVector<ReassociationIndices, 4> reassociationIndices =
+        collapseShapeOp.getReassociationIndices();
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank
+    // equal to the rank of the src of the collapse_shape. In each iteration of
+    // the loop, the offsets and sizes will be computed per reassociation group.
+    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
+    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
+                                              rewriter.getIndexAttr(1));
+
+    for (auto [groupIdx, reassocIndices] :
+         enumerate(collapseShapeOp.getReassociationIndices())) {
+      OpFoldResult collapsedSize = collapsedSizes[groupIdx];
+      OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
+      // Case #1 - size and/or offset are dynamic.
+      // In this case, the slice can be represented as a contiguous slice only
+      // if there is a single dimension in the reassociation group that has a
+      // size not equal to 1.
+      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+        int nonUnitSizeCount = 0;
+        for (int64_t expandedShapeIdx : reassocIndices) {
+          if (srcShape[expandedShapeIdx] != 1) {
+            nonUnitSizeCount++;
+            expandedSizes.emplace_back(collapsedSize);
+            expandedOffsets.emplace_back(collapsedOffset);
+            continue;
+          }
+
+          expandedSizes.emplace_back(rewriter.getIndexAttr(1));
+          expandedOffsets.emplace_back(rewriter.getIndexAttr(0));
+        }
+
+        if (nonUnitSizeCount != 1) {
+          return rewriter.notifyMatchFailure(
+              sliceOp,
+              "unsupported: slice cannot be verified to be contiguous");
+        }
+        continue;
+      }
+
+      // Case #2 - size and offset are static.
+      // Verify that the slice can be represented as a contiguous slice of the
+      // src of the collapse_shape.
+      // Checking this must be done in order of the most
+      // internal dimensions first, so traversal is done in reverse order of
+      // the reassociation group.
+      int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
+      int64_t collapsedOffsetValue =
+          getConstantIntValue(collapsedOffset).value();
+
+      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+
+      for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
+        int64_t expandedShapeSize = srcShape[expandedShapeIdx];
+
+        // This is a dimension that slicing will occur on, so we need to make
+        // sure that the slice size can be set to the shape size and the offset to 0.
+        if (collapsedSizeValue >= expandedShapeSize &&
+            (collapsedSizeValue % expandedShapeSize != 0 ||
+             collapsedOffsetValue % expandedShapeSize != 0)) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
+                       "of the src of the collapse_shape");
+        }
+
+        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+
+        // This is the dimension that slicing will occur along, so we need to
+        // make sure that the slice size + offset will not exceed the shape size.
+        if (collapsedSizeValue < expandedShapeSize &&
+            (collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "unsupported: slice cannot be extracted as a contiguous "
+                       "slice of the src of the collapse_shape");
+        }
+
+        groupExpandedSizes.emplace_back(rewriter.getIndexAttr(
+            std::min(collapsedSizeValue, expandedShapeSize)));
+        groupExpandedOffsets.emplace_back(rewriter.getIndexAttr(offsetInDim));
+
+        // Remove the size and offset of trailing dimensions from the size and
+        // offset of the slice.
+        collapsedSizeValue /= expandedShapeSize;
+        collapsedSizeValue = std::max<int64_t>(collapsedSizeValue, 1);
+        collapsedOffsetValue /= expandedShapeSize;
+      }
+
+      expandedSizes.append(groupExpandedSizes.rbegin(),
+                           groupExpandedSizes.rend());
+      expandedOffsets.append(groupExpandedOffsets.rbegin(),
+                             groupExpandedOffsets.rend());
+    }
+
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
+        expandedSizes, expandedStrides);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        sliceOp, sliceOp.getResultType(), newSliceOp,
+        collapseShapeOp.getReassociationIndices());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -448,5 +634,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
 
 void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
+               BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
 }
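
For the dynamic size/offset case described in the pattern's documentation, a minimal sketch (shapes and SSA names are illustrative, not taken from the patch) of IR the pattern can handle looks like this; the dynamic slice must land on the single non-unit dimension of its reassociation group:

```mlir
// BEFORE: the group [0, 1, 2] collapses 1x1800x1 into 1800; only one dimension
// is non-unit, so a slice with dynamic offset/size is still contiguous.
%collapse = tensor.collapse_shape %src [[0, 1, 2], [3]]
    : tensor<1x1800x1x32xf32> into tensor<1800x32xf32>
%slice = tensor.extract_slice %collapse[%off, 0] [%size, 32] [1, 1]
    : tensor<1800x32xf32> to tensor<?x32xf32>

// AFTER: the slice is bubbled up to the un-collapsed tensor; the dynamic offset
// and size are placed on the non-unit dimension, all other dimensions of the
// group get offset 0 and size 1.
%slice = tensor.extract_slice %src[0, %off, 0, 0] [1, %size, 1, 32] [1, 1, 1, 1]
    : tensor<1x1800x1x32xf32> to tensor<1x?x1x32xf32>
%collapse = tensor.collapse_shape %slice [[0, 1, 2], [3]]
    : tensor<1x?x1x32xf32> into tensor<?x32xf32>
```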

mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Lines changed: 50 additions & 0 deletions
@@ -438,3 +438,53 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape(
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -> (tensor<8x1800x32xf32>) {
+// CHECK: %[[EXTRACT1:.*]] = tensor.extract_slice
+// CHECK: %[[COLLAPSE1:.*]] = tensor.collapse_shape %[[EXTRACT1]]
+// CHECK: %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE1]]
+func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
+  %expand = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
+  %empty = tensor.empty() : tensor<8x1800x32xf32>
+  %exp = linalg.exp ins(%expand : tensor<8x1800x32xf32>) outs(%empty : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32>
+  return %exp : tensor<8x1800x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+    transform.yield
+  }
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[VAL_9:.*]] = tensor.extract_slice
+// CHECK: %[[VAL_11:.*]] = linalg.abs ins(%[[VAL_9]]
+// CHECK: %[[VAL_12:.*]] = tensor.collapse_shape %[[VAL_11]]
+// CHECK: %[[VAL_14:.*]] = linalg.exp ins(%[[VAL_12]]
+func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
+  %empty1 = tensor.empty() : tensor<1x8x1800x32xf32>
+  %abs = linalg.abs ins(%0 : tensor<1x8x1800x32xf32>) outs(%empty1 : tensor<1x8x1800x32xf32>) -> tensor<1x8x1800x32xf32>
+  %expand = tensor.collapse_shape %abs [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
+  %empty2 = tensor.empty() : tensor<8x1800x32xf32>
+  %exp = linalg.exp ins(%expand : tensor<8x1800x32xf32>) outs(%empty2 : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32>
+  return %exp : tensor<8x1800x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+    transform.yield
+  }
+}
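
These test cases are driven through the transform interpreter; a minimal RUN line of the kind such a test file typically carries (the actual RUN line of transform-op-fuse.mlir is not part of this diff and is assumed here) would be:

```mlir
// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
```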
