6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_LINALG_IR_LINALG_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -89,6 +90,11 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
int64_t dim);

/// Returns the outer shape in the packed domain before applying the
/// transposition.
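/// For example (illustrative shapes, not taken from this patch): a
/// `linalg.pack` of `tensor<128x256xf32>` into `tensor<8x16x8x32xf32>` with
/// `inner_tiles = [8, 32]` and `outer_dims_perm = [1, 0]` has packed outer
/// shape `[8, 16]`; this helper undoes the outer permutation and returns
/// `[16, 8]`.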
template <typename OpTy>
SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);

} // namespace linalg
} // namespace mlir

4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -360,6 +360,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
ArrayRef<int64_t> innerPermutation,
ArrayRef<int64_t> outerPermutation);

/// Returns true if it is statically known that the `sliceOp` result shape
/// is compatible with the `unPackOp`. I.e., it does not drop any tile.
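/// For example (an illustrative case based on the tests in this patch):
/// unpacking `tensor<28x2x1x16x16xf32>` into `tensor<28x28x15xf32>` with
/// `inner_tiles = [16, 16]` can absorb a slice down to `tensor<28x28x10xf32>`
/// (the extra padding stays below the tile size of 16), but not a slice with
/// a dynamic size or one that would leave a full tile of padding.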
bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);

/// Check if this UnPackOp is like a simple unpad operation.
/// In other words, this operation:
/// 1. drops useless dimensions (dimension of size 1), and
55 changes: 50 additions & 5 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4490,6 +4490,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//

template <typename OpTy>
SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getDestType()
: packOrUnPack.getSourceType();
RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getSourceType()
: packOrUnPack.getDestType();
SmallVector<int64_t> result(
packedType.getShape().take_front(unpackedType.getRank()));
if (!packOrUnPack.getOuterDimsPerm().empty()) {
applyPermutationToVector(
result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
}
return result;
}
template SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
template SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);

// Given the (potentially) updated packed type, `newPackedTy`, generates an
// updated mixed-tile-sizes attribute. A tile size is updated only
// when:
@@ -5447,11 +5470,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
if (unPackOp->hasOneUse()) {
auto extractSliceUser =
dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
if (extractSliceUser &&
areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
extractSliceUser.getSourceType().getRank() ==
extractSliceUser.getResultType().getRank()) {
if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
auto newDest = rewriter.create<tensor::ExtractSliceOp>(
@@ -5494,6 +5513,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return failure();
}

bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
// Rank-reduced folding is not supported.
if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
return false;
if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
!areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
return false;
RankedTensorType unpackedType = sliceOp.getResultType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(*this);
for (auto [pos, tileSize] :
llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
if (unpackedType.isDynamicDim(pos))
return false;
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
return false;
if (ShapedType::isDynamic(tileSize))
return false;
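// `paddingSize` below is the amount of data the folded unpack would have to
// drop in this dimension: the packed extent (`outer * tile`) minus the sliced
// result size. Folding stays valid only while this is strictly smaller than
// one tile; otherwise the slice would drop an entire tile.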
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
unpackedType.getDimSize(pos);
if (paddingSize >= tileSize)
return false;
[Review comment — Contributor]
It would be good to add a test where a non-trailing dim is "sliced", e.g.

func.func @fold_extract_slice_into_unpack(
    %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
) -> tensor<28x16x15xf32> {
  %unpack = linalg.unpack %src
      outer_dims_perm = [0, 1, 2]
      inner_dims_pos = [1, 2]
      inner_tiles = [16, 16]
      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
  %extracted_slice = tensor.extract_slice %unpack
      [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
  return %extracted_slice : tensor<28x16x15xf32>
}

[Review comment — Contributor]
Also, if I'm not missing something, I don't see tests for when we hit this case where we would need artificial padding. It would be nice to add these negative cases :)

[Reply — Contributor Author]
Thanks, I missed that!

I think @banach-space's example is a negative test. Anyway, I added two more tests that should address your comments.

}
return true;
}

bool UnPackOp::isLikeUnPad() {
RankedTensorType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
38 changes: 27 additions & 11 deletions mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -220,6 +220,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();

// Folding is not allowed if it introduces artificial padding. It is not
// safe to fold the ops if any dynamic dimension or tile size is present,
// because we cannot infer the padding size.
RankedTensorType unpackedType = packOp.getSourceType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(packOp);
for (auto [pos, tileSize, high] :
llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
padOp.getMixedHighPad())) {
if (unpackedType.isDynamicDim(pos))
return failure();
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
return failure();
if (ShapedType::isDynamic(tileSize))
return failure();
std::optional<int64_t> cstHigh = getConstantIntValue(high);
if (!cstHigh)
return failure();
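// `paddingSize` is the padding the pack op itself already introduces in this
// dimension; adding the pad op's high padding on top of it must still stay
// below one full tile, otherwise folding would create artificial padding.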
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
unpackedType.getDimSize(pos);
// Do not fold the op if it requires artificial padding.
if (paddingSize + cstHigh.value() >= tileSize)
return failure();
}

rewriter.replaceOpWithNewOp<PackOp>(
packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
packOp.getMixedTiles(), constantPaddingValue,
@@ -251,17 +276,8 @@ struct FoldUnpackWithExtractSliceOp
if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
return failure();

if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
return rewriter.notifyMatchFailure(
sliceOp, "rank-reduced folding is not supported");
}

// Check all offsets are zeros, and all strides are ones.
if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
!areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
return rewriter.notifyMatchFailure(
sliceOp, "expects offsets to be 0s and strides to be 1s");
}
if (!unpackOp.canFoldSliceOp(sliceOp))
return failure();

// Create a new empty output tensor.
Type elementType = unpackOp.getDestType().getElementType();
37 changes: 27 additions & 10 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1890,30 +1890,47 @@ func.func @fold_cast_unpack_dynamic_tile_size(
//===----------------------------------------------------------------------===//

func.func @fold_extract_slice_into_unpack(
%src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
) -> tensor<28x28x?xf32> {
%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
) -> tensor<28x28x10xf32> {
%unpack = linalg.unpack %src
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
%extracted_slice = tensor.extract_slice %unpack
[0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
return %extracted_slice : tensor<28x28x?xf32>
[0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
return %extracted_slice : tensor<28x28x10xf32>
}

// CHECK-LABEL: func @fold_extract_slice_into_unpack
// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
// CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
// CHECK-SAME: %[[SIZE:.+]]: index
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[SIZE:[a-zA-Z0-9]+]]
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
// CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
// CHECK-SAME: [0, 0, 0] [28, 28, 10] [1, 1, 1]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
// CHECK-SAME: into %[[DEST_SLICE]]
// CHECK: return %[[UNPACK]]

// -----

func.func @no_fold_extract_slice_into_unpack_dynamic(
%src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
) -> tensor<28x28x?xf32> {
%unpack = linalg.unpack %src
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
%extracted_slice = tensor.extract_slice %unpack
[0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
return %extracted_slice : tensor<28x28x?xf32>
}
// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
// CHECK: linalg.unpack
// CHECK: tensor.extract_slice

// -----

func.func @no_fold_extract_slice_into_unpack_rank_reducing(
%src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
) -> tensor<28xf32> {
60 changes: 44 additions & 16 deletions mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -1,22 +1,32 @@
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL

func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
func.func @fold_unpack_slice(%arg0 : tensor<2082x1x8x32xf32>) -> tensor<16649x16xf32> {
%empty = tensor.empty() : tensor<16656x16xf32>
%0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
: tensor<2082x1x8x32xf32> -> tensor<16656x16xf32>
%1 = tensor.extract_slice %0[0, 0] [16649, 16] [1, 1] : tensor<16656x16xf32> to tensor<16649x16xf32>
return %1 : tensor<16649x16xf32>
}
// CHECK-LABEL: func @fold_unpack_slice(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<16649x16xf32>
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]

// -----

func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
%0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
: tensor<?x?x8x4xf32> -> tensor<?x?xf32>
%1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK: func @fold_unpack_slice(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x8x4xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]
// CHECK-LABEL: func @nofold_dynamic_unpack_slice(
// CHECK: linalg.unpack
// CHECK: tensor.extract_slice

// -----

@@ -59,13 +69,13 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :

// -----

func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%padded = tensor.pad %src low[0, 0] high[15, 0] {
%padded = tensor.pad %src low[0, 0] high[7, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
} : tensor<16641x16xf32> to tensor<16656x16xf32>
} : tensor<16649x16xf32> to tensor<16656x16xf32>
%empty = tensor.empty() : tensor<2082x1x8x32xf32>
%pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
@@ -81,10 +91,10 @@ func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {

// -----

func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
%padded = tensor.pad %src low[0, 0] high[15, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
} : tensor<16641x16xf32> to tensor<16656x16xf32>
@@ -93,7 +103,25 @@ func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
return %pack : tensor<2082x1x8x32xf32>
}
// CHECK-LABEL: func.func @nofold_pad_pack
// CHECK-LABEL: func.func @nofold_pad_pack_artificial_padding(
// CHECK: tensor.pad
// CHECK: linalg.pack

// -----

func.func @nofold_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
} : tensor<16649x16xf32> to tensor<16656x16xf32>
%empty = tensor.empty() : tensor<2082x1x8x32xf32>
%pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
return %pack : tensor<2082x1x8x32xf32>
}
// CHECK-LABEL: func.func @nofold_pad_pack(
// CHECK: tensor.pad
// CHECK: linalg.pack
