Commit 1ff6d9d

[mlir][linalg] Take artificial padding into account for pack/unpack folding. (#150272)
The revision folds the tensor.pad/extract_slice op into linalg.pack/unpack ops only when it is safe to fold, i.e., when the fold does not introduce artificial padding. The documentation improvement and verifier update will be done in a separate PR (#149624); this revision is a step towards it. --------- Signed-off-by: hanhanW <[email protected]>
1 parent efe1aa8 commit 1ff6d9d
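For context: padding in a packed layout is "artificial" when a packed dimension carries at least one full tile that is nothing but padding; the new checks reject any fold that would create such a tile. A minimal sketch of the rejected case, mirroring the @no_fold_extract_slice_into_unpack_artificial_padding test added below:

// Dim 1 of the unpacked result holds 2 tiles of 16 (32 elements in total).
// Slicing it down to 16 would leave the entire second tile as padding, so
// folding the extract_slice into the unpack is rejected.
%unpack = linalg.unpack %src inner_dims_pos = [1, 2] inner_tiles = [16, 16]
    into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
%slice = tensor.extract_slice %unpack [0, 0, 0] [28, 16, 15] [1, 1, 1]
    : tensor<28x28x15xf32> to tensor<28x16x15xf32>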

File tree

6 files changed, +270 −52 lines changed

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 14 additions & 0 deletions

@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_IR_LINALG_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -144,4 +145,17 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
 
+namespace mlir {
+namespace linalg {
+
+/// Returns the outer shape in the packed domain before applying the
+/// transposition.
+template <typename OpTy,
+          typename = std::enable_if_t<std::is_same_v<OpTy, linalg::PackOp> ||
+                                      std::is_same_v<OpTy, linalg::UnPackOp>>>
+SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
+
+} // namespace linalg
+} // namespace mlir
+
 #endif // MLIR_DIALECT_LINALG_IR_LINALG_H
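To illustrate what the new helper returns, a hypothetical example (the op below is illustrative and not part of the patch):

%pack = linalg.pack %src outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1] inner_tiles = [8, 16]
    into %dest : tensor<16x64xf32> -> tensor<4x2x8x16xf32>
// For this op the helper returns [2, 4]: the dest's leading dims [4, 2]
// with the inverse of outer_dims_perm applied, i.e. the tile counts
// ceilDiv(16, 8) = 2 and ceilDiv(64, 16) = 4 in source dim order.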

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 4 additions & 0 deletions

@@ -360,6 +360,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
                     ArrayRef<int64_t> innerPermutation,
                     ArrayRef<int64_t> outerPermutation);
 
+    /// Returns true if it is statically known that the `sliceOp` result shape
+    /// is compatible with the `unPackOp`. I.e., it does not drop any tile.
+    bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);
+
     /// Check if this UnPackOp is like a simple unpad operation.
     /// In other words, this operation:
     /// 1. drops useless dimensions (dimension of size 1), and
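Concretely, for the tests below with source tensor<28x2x1x16x16xf32> (two tiles of 16 along dim 1), any slice size in [17, 32] along that dim keeps data in every tile, so canFoldSliceOp returns true; a size of 16 or less drops the entire second tile, and it returns false.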

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 50 additions & 5 deletions

@@ -4492,6 +4492,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
+
+template <typename OpTy, typename>
+SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
+  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+                                    ? packOrUnPack.getDestType()
+                                    : packOrUnPack.getSourceType();
+  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                      ? packOrUnPack.getSourceType()
+                                      : packOrUnPack.getDestType();
+  SmallVector<int64_t> result(
+      packedType.getShape().take_front(unpackedType.getRank()));
+  if (!packOrUnPack.getOuterDimsPerm().empty()) {
+    applyPermutationToVector(
+        result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
+  }
+  return result;
+}
+template SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
+template SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
+
 // Given the (potentially) updated packed type, `newPackedTy`, generates an
 // updated mixed-tile-sizes attribute. A tile size is updated only
 // when:
@@ -5452,11 +5475,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   if (unPackOp->hasOneUse()) {
     auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
-    if (extractSliceUser &&
-        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
-        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
-        extractSliceUser.getSourceType().getRank() ==
-            extractSliceUser.getResultType().getRank()) {
+    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
       OpBuilder::InsertionGuard g(rewriter);
       rewriter.setInsertionPoint(unPackOp);
       auto newDest = tensor::ExtractSliceOp::create(
@@ -5499,6 +5518,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   return failure();
 }
 
+bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
+  // Rank-reduced folding is not supported.
+  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
+    return false;
+  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+      !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
+    return false;
+  RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
+  SmallVector<int64_t> outerShapeWithoutTranspose =
+      getPackedOuterShapeWithoutTransposition(*this);
+  for (auto [pos, tileSize] :
+       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+    if (unpackedTypeAfterFold.isDynamicDim(pos))
+      return false;
+    if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+      return false;
+    if (ShapedType::isDynamic(tileSize))
+      return false;
+    int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+                          unpackedTypeAfterFold.getDimSize(pos);
+    if (paddingSize >= tileSize)
+      return false;
+  }
+  return true;
+}
+
 bool UnPackOp::isLikeUnPad() {
   RankedTensorType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
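Working through the new check with the numbers from the tests below: for the slice to tensor<28x17x15xf32>, dim 1 gives paddingSize = 2 * 16 - 17 = 15 < 16, so the fold proceeds; for the slice to tensor<28x16x15xf32>, paddingSize = 2 * 16 - 16 = 16 >= 16 (a full tile of padding), so canFoldSliceOp returns false.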

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 29 additions & 11 deletions

@@ -220,6 +220,33 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
     if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
       return failure();
 
+    // Folding is not allowed if it were to introduce artificial padding.
+    // Folding is also disabled in the case of dynamic dimensions and/or tile
+    // sizes - that is because it would be impossible to compute the padding
+    // size and hence to establish whether "artificial" padding would be
+    // created.
+    RankedTensorType unpackedType = packOp.getSourceType();
+    SmallVector<int64_t> outerShapeWithoutTranspose =
+        getPackedOuterShapeWithoutTransposition(packOp);
+    for (auto [pos, tileSize, high] :
+         llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+                         padOp.getMixedHighPad())) {
+      if (unpackedType.isDynamicDim(pos))
+        return failure();
+      if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+        return failure();
+      if (ShapedType::isDynamic(tileSize))
+        return failure();
+      std::optional<int64_t> cstHigh = getConstantIntValue(high);
+      if (!cstHigh)
+        return failure();
+      int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+                            unpackedType.getDimSize(pos);
+      // Do not fold the op if it requires artificial padding.
+      if (paddingSize + cstHigh.value() >= tileSize)
+        return failure();
+    }
+
     rewriter.replaceOpWithNewOp<PackOp>(
         packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
         packOp.getMixedTiles(), constantPaddingValue,
@@ -251,17 +278,8 @@ struct FoldUnpackWithExtractSliceOp
     if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
       return failure();
 
-    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "rank-reduced folding is not supported");
-    }
-
-    // Check all offsets are zeros, and all strides are ones.
-    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
-        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "expects offsets to be 0s and strides to be 1s");
-    }
+    if (!unpackOp.canFoldSliceOp(sliceOp))
+      return failure();
 
     // Create a new empty output tensor.
     Type elementType = unpackOp.getDestType().getElementType();
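To see the new FoldPadWithPackOp guard in action, a hypothetical IR snippet (not taken from the patch):

%cst = arith.constant 0.0 : f32
%padded = tensor.pad %src low[0] high[1] {
^bb0(%i: index):
  tensor.yield %cst : f32
} : tensor<15xf32> to tensor<16xf32>
%pack = linalg.pack %padded inner_dims_pos = [0] inner_tiles = [8]
    into %dest : tensor<16xf32> -> tensor<2x8xf32>
// Total padding on dim 0 is paddingSize + cstHigh = (2 * 8 - 16) + 1 = 1 < 8,
// so the pad folds away: %src is packed directly, with %cst supplied as the
// padding_value. Had the pad gone to tensor<24xf32> and the pack to
// tensor<3x8xf32>, the sum would be 9 >= 8 (a full tile of padding) and the
// pattern would bail out.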

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 64 additions & 11 deletions

@@ -1889,31 +1889,84 @@ func.func @fold_cast_unpack_dynamic_tile_size(
 // linalg.unpack + tensor.extract_slice
 //===----------------------------------------------------------------------===//
 
-func.func @fold_extract_slice_into_unpack(
-    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
+func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x28x10xf32> {
   %unpack = linalg.unpack %src
       outer_dims_perm = [0, 1, 2]
       inner_dims_pos = [1, 2]
       inner_tiles = [16, 16]
-      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
   %extracted_slice = tensor.extract_slice %unpack
-      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
-  return %extracted_slice : tensor<28x28x?xf32>
+      [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+  return %extracted_slice : tensor<28x28x10xf32>
 }
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+// CHECK-SAME:      [0, 0, 0] [28, 28, 10] [1, 1, 1]
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME:      into %[[DEST_SLICE]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
+// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2.
 
-// CHECK-LABEL: func @fold_extract_slice_into_unpack
-// CHECK-SAME:    %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-// CHECK-SAME:    %[[DEST:.+]]: tensor<28x32x?xf32>
-// CHECK-SAME:    %[[SIZE:.+]]: index
+func.func @fold_extract_slice_into_unpack_slicing_dim_1(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x17x15xf32> {
+  %unpack = linalg.unpack %src
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32>
+  return %extracted_slice : tensor<28x17x15xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
 // CHECK:         %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-// CHECK-SAME:      [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+// CHECK-SAME:      [0, 0, 0] [28, 17, 15] [1, 1, 1]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
 // CHECK-SAME:      into %[[DEST_SLICE]]
 // CHECK:         return %[[UNPACK]]
 
 // -----
 
+// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2.
+
+func.func @no_fold_extract_slice_into_unpack_artificial_padding(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x16x15xf32> {
+  %unpack = linalg.unpack %src
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
+  return %extracted_slice : tensor<28x16x15xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding
+// CHECK:         linalg.unpack
+// CHECK:         tensor.extract_slice
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_dynamic(
+    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+  return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+// CHECK:         linalg.unpack
+// CHECK:         tensor.extract_slice
+
+// -----
+
 func.func @no_fold_extract_slice_into_unpack_rank_reducing(
     %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
 ) -> tensor<28xf32> {
