From 68a7815d4eb937f5b3f78b4c02057a4eeaa8f13a Mon Sep 17 00:00:00 2001
From: Ofri Frishman
Date: Wed, 12 Feb 2025 13:12:48 +0200
Subject: [PATCH] [MLIR] Add pattern to bubble up tensor.extract_slice

Add a pattern that bubbles up tensor.extract_slice through
tensor.expand_shape, and add a transform op to the tensor dialect to
directly use this pattern. This pattern enables tiling and fusing op
chains which contain tensor.expand_shape when it is added as a cleanup
pattern of the tile-and-fuse utility. Without this pattern that would
not be possible, as tensor.expand_shape does not implement the tiling
interface. In addition, the pattern is registered as a cleanup pattern
for transform.structured.fuse.

The pattern was first implemented in the IREE project by Quinn Dawkins
and is being upstreamed.
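
For example (mirroring the first test added below), the pattern
rewrites

  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10]
      : tensor<60xf32> into tensor<2x3x10xf32>
  %slice = tensor.extract_slice %expand[0, 0, 5] [1, 1, 5] [1, 1, 1]
      : tensor<2x3x10xf32> to tensor<1x1x5xf32>

into

  %slice = tensor.extract_slice %0[5] [5] [1]
      : tensor<60xf32> to tensor<5xf32>
  %expand = tensor.expand_shape %slice [[0, 1, 2]] output_shape [1, 1, 5]
      : tensor<5xf32> into tensor<1x1x5xf32>

where the new offset 5 is the linearization of the original offsets
(0, 0, 5) by the basis (2, 3, 10).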
---------

Co-authored-by: Quinn Dawkins
---
 .../Tensor/TransformOps/TensorTransformOps.td |  11 +
 .../Dialect/Tensor/Transforms/Transforms.h    |   6 +
 .../TransformOps/LinalgTransformOps.cpp       |   1 +
 .../TransformOps/TensorTransformOps.cpp       |   5 +
 .../Tensor/Transforms/ReshapePatterns.cpp     | 217 ++++++++++++++++++
 .../Dialect/Linalg/transform-op-fuse.mlir     | 160 +++++++++++++
 .../Tensor/bubble-up-extract-slice-op.mlir    | 124 ++++++++++
 7 files changed, 524 insertions(+)
 create mode 100644 mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir

diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 81bab1b0c82f7..4d01dea55ca2f 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -111,6 +111,17 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyBubbleUpExtractSlicePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.bubble_up_extract_slice",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.extract_slice ops should be swapped with their
+    producers, so that the producers operate on the result of the slice.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.rewrite_as_constant",
    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index ae695e0326ca1..dc4558a605a59 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -58,6 +58,12 @@ void populateFoldTensorSubsetIntoVectorTransferPatterns(
 void populateMergeConsecutiveInsertExtractSlicePatterns(
     RewritePatternSet &patterns);
 
+/// Appends patterns that are used to bubble up the tensor.extract_slice op
+/// above its producer. When used as cleanup patterns of tile and fuse, this
+/// enables fusing the producer with the consumer even if the producer does
+/// not implement the tiling interface.
+void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that drop redundant tensor.insert_slice
 /// rank expansions.
 void populateDropRedundantInsertSliceRankExpansionPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 51d1df52598c7..5146bebe0108e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -582,6 +582,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
     RewritePatternSet patterns(context);
     tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
     tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+    tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
     tileAndFuseOptions.cleanupPatterns = std::move(patterns);
   }
 
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 99199252710f9..e4c60781ca270 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -125,6 +125,11 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
 }
 
+void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
+}
+
 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 5edd7a02bc42b..ae8e3528b02e0 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -6,10 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -210,6 +214,214 @@ struct BubbleUpExpandThroughParallelCollapse
   }
 };
 
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+///
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. A slice is defined as
+/// fully contiguous within a reassociation group if, after flattening the
+/// reassociation group to a single 1D range, the slice taken out of the group
+/// forms a single contiguous subrange of that range.
+///
+/// Rank reducing slices are not supported.
+///
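+/// For instance (taken from the tests added with this pattern), when
+/// expanding tensor<60xf32> to tensor<2x3x10xf32> with reassociation
+/// [[0, 1, 2]], a slice with offsets [0, 0, 5] and sizes [1, 1, 5] is
+/// contiguous (it maps to the flat subrange [5, 10)), whereas sizes
+/// [1, 2, 5] are not, since the two rows of length 5 are not adjacent in
+/// the flattened 60-element range.
+///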
+/// Example:
+/// The transformation is possible because each reassociation group has a
+/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
+/// ```
+/// BEFORE:
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note: this pattern could be extended to be a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but it is currently
+/// implemented only as a bubble-up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+
+    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
+                                                 rewriter)
+            .failed())
+      return failure();
+
+    // The tensor.extract_slice before applying the pattern works on the
+    // result of the tensor.expand_shape, so variables (i.e. inputs for
+    // ExtractSliceOp) referring to the state before applying the pattern are
+    // named with the prefix "expanded", and ones referring to the state after
+    // applying the pattern are named with the prefix "collapsed".
+    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> expandedShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    // Helper variables and function for accumulating the size values.
+    Location loc = expandShapeOp->getLoc();
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    // Multiply two integers.
+    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                   {v1, v2});
+    };
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank of
+    // ReassociationIndices.size(). In the loop, a single offset, size, and
+    // stride value is computed per reassociation group.
+    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
+        collapsedStrides;
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      // collapsedSize will hold the size of the single dim that represents
+      // the reassociation group in the non-expanded tensor.
+      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
+      // The reassocGroupSizes and reassocGroupOffsets are used to create an
+      // affine.linearize_index op to linearize the single offset value
+      // required for this reassociation group.
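+      // E.g. (mirroring the tests accompanying this pattern), for a group
+      // with sizes (3, 4, 10) and offsets (%x, %y, 0), the linearized offset
+      // is %x * 40 + %y * 10 + 0.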
+      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
+
+      for (long expandedDim : indices) {
+        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
+        // from the expanded state, but the collapsed size requires
+        // calculation as it did not previously exist.
+        reassocGroupSizes.push_back(expandedShape[expandedDim]);
+        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
+        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+      }
+
+      SmallVector<Value> offsetVals =
+          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
+            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+          });
+      OpFoldResult collapsedOffset =
+          rewriter
+              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
+                                                      reassocGroupSizes,
+                                                      /*disjoint=*/true)
+              .getResult();
+      collapsedOffsets.push_back(collapsedOffset);
+      collapsedSizes.push_back(collapsedSize);
+
+      // Only unit stride is supported.
+      collapsedStrides.push_back(rewriter.getIndexAttr(1));
+    }
+
+    // The shape of the result can be obtained from the sizes passed in.
+    SmallVector<Value> dynDims;
+    SmallVector<int64_t> shape;
+    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
+    RankedTensorType resultType = RankedTensorType::get(
+        shape, expandShapeOp.getResultType().getElementType());
+
+    // Create a new ExtractSliceOp and ExpandShapeOp.
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
+        collapsedStrides);
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        sliceOp, resultType, newSliceOp,
+        expandShapeOp.getReassociationIndices(), expandedSizes);
+    return success();
+  }
+
+  // Helper function to check if all the required conditions for the
+  // tensor.extract_slice to be bubbled up through the tensor.expand_shape
+  // are met.
+  LogicalResult
+  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
+                                           tensor::ExpandShapeOp expandShapeOp,
+                                           PatternRewriter &rewriter) const {
+
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "tensor.extract_slice source not produced by expand_shape");
+    }
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+                   "be supported in this transformation.");
+    }
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        sizes.size()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+    }
+
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
+        isZeroOffsetAndFullSize =
+            [](OpFoldResult offset, OpFoldResult sliceSize,
+               OpFoldResult size) {
+              if (!isConstantIntValue(offset, 0))
+                return false;
+              FailureOr<bool> maybeEqual =
+                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
+              return llvm::succeeded(maybeEqual) && maybeEqual.value();
+            };
+
+    // Check that the slice is contiguous within each reassociation group.
+    // The slice is contiguous only if, after the first dimension where a
+    // non-unit slice is taken, the slice size on all subsequent dimensions
+    // of the group is equal to the entire size of the dimension.
+    // Examples of contiguous slices:
+    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+    //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+    // Examples of non-contiguous slices:
+    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+    //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Find the first expanded dim after the first dim with non-unit
+      // extracted size.
+      for (; i < e; ++i) {
+        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+          // +1 to skip the first non-unit size dim.
+          i++;
+          break;
+        }
+      }
+
+      // Verify that all subsequent dimensions extract the full size of the
+      // source tensor.
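+      // (E.g. for full sizes [5, 10] with slice sizes [2, 10]: the scan
+      // above stops after dim 0, and dim 1 must then be taken in full, with
+      // offset 0 and size 10.)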
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                     outputShape[expandedDim])) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "Not a contiguous slice of the expanded tensor.");
+        }
+      }
+    }
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +439,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
     RewritePatternSet &patterns) {
   patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }
+
+void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index ac1ca9319d335..32740970ea6eb 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -278,3 +278,163 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[Z]]] by (2, 3, 10)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5]
+// CHECK: linalg.exp ins(%[[EXPAND]]
+func.func @bubble_up_extract_slice_through_expand_shape(%0: tensor<60xf32>) -> tensor<2x3x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %empty = tensor.empty() : tensor<2x3x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<2x3x10xf32>) outs(%empty : tensor<2x3x10xf32>) -> tensor<2x3x10xf32>
+  return %exp : tensor<2x3x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape_full_inner_dim
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]]{{.*}} by (3, 4, 10)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [20] [1] : tensor<120xf32> to tensor<20xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 2, 10]
+// CHECK: linalg.exp ins(%[[EXPAND]]
+func.func @bubble_up_extract_slice_through_expand_shape_full_inner_dim(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+  %empty = tensor.empty() : tensor<3x4x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+  return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
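+  // With apply_cleanup = true, the cleanup patterns of tile-and-fuse bubble
+  // the tiled tensor.extract_slice up through tensor.expand_shape, so each
+  // iteration reads one contiguous 20-element slice of the flat input (see
+  // the CHECK lines above).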
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %transformed, %loops:2 = transform.structured.fuse %0 [1, 2, 0] interchange [0, 1, 2] apply_cleanup = true : + (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_expand_shape_non_contiguous +// CHECK: tensor.expand_shape +// CHECK: scf.for +// CHECK: scf.for +// CHECK: scf.for +// CHECK: linalg.exp +func.func @no_bubble_up_extract_slice_through_expand_shape_non_contiguous(%0: tensor<120xf32>) -> tensor<3x4x10xf32> { + %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32> + %empty = tensor.empty() : tensor<3x4x10xf32> + %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32> + return %exp : tensor<3x4x10xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %transformed, %loops:3 = transform.structured.fuse %0 [1, 2, 5] interchange [0, 1, 2] apply_cleanup = true : + (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape_multiple_expanded_dims +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} +// CHECK: scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}} +// CHECK: scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}} +// CHECK: scf.for %[[W:[A-Za-z0-9]+]] = {{.*}} +// CHECK: %[[LINEAR_IDX0:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[C0]]] by (3, 4, 10) +// CHECK: %[[LINEAR_IDX1:.+]] = affine.linearize_index disjoint [%[[Z]], %[[W]]] by (7, 8) +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX0]], %[[LINEAR_IDX1]]] [20, 4] [1, 1] : tensor<120x56xf32> to tensor<20x4xf32> +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4] +// CHECK: linalg.exp ins(%[[EXPAND]] +module { + func.func @bubble_up_extract_slice_through_expand_shape_multiple_expanded_dims(%0: tensor<120x56xf32>) -> tensor<3x4x10x7x8xf32> { + %expand = tensor.expand_shape %0 [[0, 1, 2], [3, 4]] output_shape [3, 4, 10, 7, 8] : tensor<120x56xf32> into tensor<3x4x10x7x8xf32> + %empty = tensor.empty() : tensor<3x4x10x7x8xf32> + %exp = linalg.exp ins(%expand : tensor<3x4x10x7x8xf32>) outs(%empty : tensor<3x4x10x7x8xf32>) -> tensor<3x4x10x7x8xf32> + return %exp : tensor<3x4x10x7x8xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %transformed, %loops:4 = transform.structured.fuse %0 [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] apply_cleanup = true : + (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func 
@bubble_up_extract_slice_through_expand_shape_and_fuse_with_expand_producer +// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} +// CHECK: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], {{.*}} by (8, 32) +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[0, 0, %[[LINEAR_IDX]]] [1, 1800, 32] [1, 1, 1] : tensor<1x1800x256xf32> to tensor<1x1800x32xf32> +// CHECK: %[[ABS:.+]] = linalg.abs ins(%[[SLICE]] +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ABS]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 1800, 1, 32] +// CHECK: linalg.exp ins(%[[EXPAND]] +module { + func.func @bubble_up_extract_slice_through_expand_shape_and_fuse_with_expand_producer(%0: tensor<1x1800x256xf32>) -> tensor<1x1800x8x32xf32> { + %empty1 = tensor.empty() : tensor<1x1800x256xf32> + %exp1 = linalg.abs ins(%0 : tensor<1x1800x256xf32>) outs(%empty1 : tensor<1x1800x256xf32>) -> tensor<1x1800x256xf32> + %expand = tensor.expand_shape %exp1 [[0], [1], [2, 3]] output_shape [1, 1800, 8, 32] : tensor<1x1800x256xf32> into tensor<1x1800x8x32xf32> + %empty2 = tensor.empty() : tensor<1x1800x8x32xf32> + %exp2 = linalg.exp ins(%expand : tensor<1x1800x8x32xf32>) outs(%empty2 : tensor<1x1800x8x32xf32>) -> tensor<1x1800x8x32xf32> + return %exp2 : tensor<1x1800x8x32xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %transformed, %loops:1 = transform.structured.fuse %0 [0, 0, 1, 0] interchange [0, 1, 2, 3] apply_cleanup = true : + (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_expand_shape_on_cleanup_false +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape {{.*}} {{\[\[}}0, 1, 2]] output_shape [2, 3, 10] +// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} +// CHECK: scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}} +// CHECK: scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}} +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[EXPAND]]{{.*}} [1, 1, 5] [1, 1, 1] : tensor<2x3x10xf32> to tensor<1x1x5xf32> +// CHECK: linalg.exp ins(%[[SLICE]] +func.func @no_bubble_up_extract_slice_through_expand_shape_on_cleanup_false(%0: tensor<60xf32>) -> tensor<2x3x10xf32> { + %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32> + %empty = tensor.empty() : tensor<2x3x10xf32> + %exp = linalg.exp ins(%expand : tensor<2x3x10xf32>) outs(%empty : tensor<2x3x10xf32>) -> tensor<2x3x10xf32> + return %exp : tensor<2x3x10xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = false : + (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir new file mode 100644 index 0000000000000..252e7494bff79 --- /dev/null +++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir @@ -0,0 +1,124 @@ +// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s + +// CHECK-LABEL: 
func.func @bubble_up_extract_slice_through_expand_shape( +// CHECK-SAME: %[[SRC:.*]]: tensor<60xf32>) -> tensor<1x1x5xf32> { +// CHECK: %[[C1:.+]] = arith.constant 5 : index +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C1]]] [5] [1] : tensor<60xf32> to tensor<5xf32> +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5] : tensor<5xf32> into tensor<1x1x5xf32> +// CHECK: return %[[EXPAND]] : tensor<1x1x5xf32> + +func.func @bubble_up_extract_slice_through_expand_shape(%src: tensor<60xf32>) -> tensor<1x1x5xf32> { + %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32> + %extract = tensor.extract_slice %expand[0, 0, 5][1, 1, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x1x5xf32> + return %extract : tensor<1x1x5xf32> +} + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_on_non_contiguous( +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +// CHECK: return %[[EXTRACT]] + +func.func @no_bubble_up_extract_slice_on_non_contiguous(%src: tensor<60xf32>) -> tensor<1x2x5xf32> { + %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32> + %extract = tensor.extract_slice %expand[0, 0, 5][1, 2, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x2x5xf32> + return %extract : tensor<1x2x5xf32> +} + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_on_stride( +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +// CHECK: return %[[EXTRACT]] + +func.func @no_bubble_up_extract_slice_on_stride(%src: tensor<60xf32>) -> tensor<1x1x5xf32> { + %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32> + %extract = tensor.extract_slice %expand[0, 0, 5][1, 1, 5][1, 1, 2] : tensor<2x3x10xf32> to tensor<1x1x5xf32> + return %extract : tensor<1x1x5xf32> +} + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_on_rank_reducing( +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +// CHECK: return %[[EXTRACT]] + +func.func @no_bubble_up_extract_slice_on_rank_reducing(%src: tensor<60xf32>) -> tensor<1x5xf32> { + %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32> + %extract = tensor.extract_slice %expand[0, 0, 5][1, 1, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x5xf32> + return %extract : tensor<1x5xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape_multiple_expanded_dims( +// CHECK-SAME: %[[SRC:.*]]: tensor<120x56xf32>) -> tensor<1x2x10x1x4xf32> { +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C0]], %[[C0]]] [20, 4] [1, 1] : tensor<120x56xf32> to tensor<20x4xf32> +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4] : tensor<20x4xf32> into tensor<1x2x10x1x4xf32> +// CHECK: return %[[EXPAND]] : tensor<1x2x10x1x4xf32> + +func.func @bubble_up_extract_slice_through_expand_shape_multiple_expanded_dims(%src: tensor<120x56xf32>) -> tensor<1x2x10x1x4xf32> { + %expand = tensor.expand_shape %src [[0, 1, 2], [3, 4]] output_shape [3, 4, 10, 7, 8] : tensor<120x56xf32> into tensor<3x4x10x7x8xf32> + %extract = tensor.extract_slice %expand[0, 0, 0, 0, 0][1, 2, 10, 1, 4][1, 1, 1, 1, 1] : tensor<3x4x10x7x8xf32> to tensor<1x2x10x1x4xf32> + return %extract : 
tensor<1x2x10x1x4xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_with_trailing_full_dims( +// CHECK-SAME: %[[SRC:.*]]: tensor<60xf32>) -> tensor<2x5x2xf32> { +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C0]]] [20] [1] : tensor<60xf32> to tensor<20xf32> +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] output_shape [2, 5, 2] : tensor<20xf32> into tensor<2x5x2xf32> +// CHECK: return %[[EXPAND]] : tensor<2x5x2xf32> +func.func @bubble_up_extract_slice_with_trailing_full_dims(%src: tensor<60xf32>) -> tensor<2x5x2xf32> { + %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [6, 5, 2] : tensor<60xf32> into tensor<6x5x2xf32> + %extract = tensor.extract_slice %expand[0, 0, 0][2, 5, 2][1, 1, 1] : tensor<6x5x2xf32> to tensor<2x5x2xf32> + return %extract : tensor<2x5x2xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_dont_fold_linearize_index( +// CHECK-SAME: %[[SRC:.*]]: tensor<60xf32>, +// CHECK-SAME: %[[OFFSET_0:.*]]: index, +// CHECK-SAME: %[[OFFSET_1:.*]]: index) -> tensor<1x1x5xf32> { +// CHECK: %[[C1:.+]] = arith.constant 5 : index +// CHECK: %[[LINEARIZE:.*]] = affine.linearize_index disjoint {{\[}}%[[OFFSET_0]], %[[OFFSET_1]], %[[C1]]] by (2, 3, 10) : index +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[LINEARIZE]]] [5] [1] : tensor<60xf32> to tensor<5xf32> +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5] : tensor<5xf32> into tensor<1x1x5xf32> +// CHECK: return %[[EXPAND]] : tensor<1x1x5xf32> +func.func @bubble_up_extract_slice_dont_fold_linearize_index(%src: tensor<60xf32>, %offset_0 : index, %offset_1 : index) -> tensor<1x1x5xf32> { + %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32> + %extract = tensor.extract_slice %expand[%offset_0, %offset_1, 5][1, 1, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x1x5xf32> + return %extract : tensor<1x1x5xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_not_all_dims_expanded( +// CHECK-SAME: %[[SRC:.*]]: tensor<60x12xf32>) -> tensor<1x1x5x12xf32> { +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[C5]], %[[C0]]] [5, 12] [1, 1] : tensor<60x12xf32> to tensor<5x12xf32> +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2], [3]] output_shape [1, 1, 5, 12] : tensor<5x12xf32> into tensor<1x1x5x12xf32> +// CHECK: return %[[EXPAND]] : tensor<1x1x5x12xf32> +func.func @bubble_up_extract_slice_not_all_dims_expanded(%src: tensor<60x12xf32>) -> tensor<1x1x5x12xf32> { + %expand = tensor.expand_shape %src [[0, 1, 2], [3]] output_shape [2, 3, 10, 12] : tensor<60x12xf32> into tensor<2x3x10x12xf32> + %extract = tensor.extract_slice %expand[0, 0, 5, 0][1, 1, 5, 12][1, 1, 1, 1] : tensor<2x3x10x12xf32> to tensor<1x1x5x12xf32> + return %extract : tensor<1x1x5x12xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_affine_apply_not_folded( +// CHECK-SAME: %[[SRC:.*]]: tensor<60xf32>, +// CHECK-SAME: %[[SLICE_SIZE:.*]]: index) -> tensor { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[AFFINE_APPLY:.*]] = affine.apply #map(){{\[}}%[[SLICE_SIZE]]] +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[C0]]] {{\[}}%[[AFFINE_APPLY]]] [1] : tensor<60xf32> to tensor +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] 
{{\[\[}}0, 1, 2]] output_shape {{\[}}%[[SLICE_SIZE]], 5, 2] : tensor into tensor +// CHECK: return %[[EXPAND]] : tensor +func.func @bubble_up_extract_slice_affine_apply_not_folded(%src: tensor<60xf32>, %slice_size : index) -> tensor { + %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [6, 5, 2] : tensor<60xf32> into tensor<6x5x2xf32> + %extract = tensor.extract_slice %expand[0, 0, 0][%slice_size, 5, 2][1, 1, 1] : tensor<6x5x2xf32> to tensor + return %extract : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) { + %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> + transform.apply_patterns to %func_op { + transform.apply_patterns.tensor.bubble_up_extract_slice + } : !transform.op<"func.func"> + transform.yield + } +}