Commit 601f21a

address comments
Signed-off-by: hanhanW <[email protected]>
1 parent 2fc0316

File tree: 5 files changed (+82, -31 lines changed)

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 13 additions & 5 deletions

@@ -90,11 +90,6 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
 OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
                                int64_t dim);
 
-/// Returns the outer shape in the packed domain before applying the
-/// transposition.
-template <typename OpTy>
-SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
-
 } // namespace linalg
 } // namespace mlir
 

@@ -150,4 +145,17 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
 
+namespace mlir {
+namespace linalg {
+
+/// Returns the outer shape in the packed domain before applying the
+/// transposition.
+template <typename OpTy,
+          typename = std::enable_if_t<std::is_same_v<OpTy, linalg::PackOp> ||
+                                      std::is_same_v<OpTy, linalg::UnPackOp>>>
+SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
+
+} // namespace linalg
+} // namespace mlir
+
 #endif // MLIR_DIALECT_LINALG_IR_LINALG_H
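Note on the move: the declaration now sits below the LinalgRelayoutOps.h.inc include, so linalg::PackOp and linalg::UnPackOp are already declared when the std::enable_if_t constraint names them. The sketch below is a minimal, self-contained illustration of the same SFINAE idiom with hypothetical stand-in op types; it is not the MLIR code itself.

#include <cstdint>
#include <type_traits>
#include <vector>

// Stand-ins for illustration only; in MLIR these are the linalg::PackOp and
// linalg::UnPackOp classes generated into LinalgRelayoutOps.h.inc.
struct PackOp {};
struct UnPackOp {};
struct GenericOp {};

// The defaulted second template parameter removes this function from overload
// resolution for any OpTy other than PackOp or UnPackOp, so a misuse fails
// with a clear compile-time error.
template <typename OpTy,
          typename = std::enable_if_t<std::is_same_v<OpTy, PackOp> ||
                                      std::is_same_v<OpTy, UnPackOp>>>
std::vector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy op) {
  return {}; // The real implementation lives in LinalgOps.cpp.
}

int main() {
  auto a = getPackedOuterShapeWithoutTransposition(PackOp{});   // OK
  auto b = getPackedOuterShapeWithoutTransposition(UnPackOp{}); // OK
  // getPackedOuterShapeWithoutTransposition(GenericOp{}); // error: no match
  return static_cast<int>(a.size() + b.size());
}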

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 4 additions & 4 deletions

@@ -4491,7 +4491,7 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
 
-template <typename OpTy>
+template <typename OpTy, typename>
 SmallVector<int64_t>
 getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
   RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)

@@ -5520,19 +5520,19 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
   if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
       !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
     return false;
-  RankedTensorType unpackedType = sliceOp.getResultType();
+  RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
   SmallVector<int64_t> outerShapeWithoutTranspose =
       getPackedOuterShapeWithoutTransposition(*this);
   for (auto [pos, tileSize] :
        llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
-    if (unpackedType.isDynamicDim(pos))
+    if (unpackedTypeAfterFold.isDynamicDim(pos))
       return false;
     if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
       return false;
     if (ShapedType::isDynamic(tileSize))
       return false;
     int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
-                          unpackedType.getDimSize(pos);
+                          unpackedTypeAfterFold.getDimSize(pos);
     if (paddingSize >= tileSize)
       return false;
   }
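For intuition, the paddingSize bound can be replayed with the shapes used in the new canonicalize.mlir tests below: dim 1 of tensor<28x2x1x16x16xf32> has outer size 2 and inner tile 16, so any sliced extent in (16, 32] keeps the padding under one tile. A minimal standalone sketch of just that arithmetic (variable names are illustrative, not from the MLIR sources):

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Replays the paddingSize check from UnPackOp::canFoldSliceOp with concrete
// numbers: outer size 2, inner tile 16, so the full data extent is 2 * 16 = 32.
int main() {
  const int64_t outerSize = 2;  // outerShapeWithoutTranspose[pos]
  const int64_t tileSize = 16;  // static inner tile at pos
  for (int64_t slicedExtent : {17, 16}) {
    const int64_t paddingSize = outerSize * tileSize - slicedExtent;
    std::printf("extent %2lld -> paddingSize %2lld -> %s\n",
                static_cast<long long>(slicedExtent),
                static_cast<long long>(paddingSize),
                paddingSize >= tileSize ? "no fold (artificial padding)"
                                        : "fold is legal");
  }
  return 0;
}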

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 5 additions & 3 deletions

@@ -220,9 +220,11 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
     if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
       return failure();
 
-    // Folding is not allowed if it introduces artificial padding. It is not
-    // safe to fold the ops if any dynamic dimension or tile size is present,
-    // because we can not infer the padding size.
+    // Folding is not allowed if it were to introduce artificial padding.
+    // Folding is also disabled in the case of dynamic dimensions and/or tile
+    // sizes - that is because it would be impossible to compute the padding
+    // size and hence to establish whether "artificial" padding would be
+    // created.
     RankedTensorType unpackedType = packOp.getSourceType();
     SmallVector<int64_t> outerShapeWithoutTranspose =
         getPackedOuterShapeWithoutTransposition(packOp);
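The same bound drives this pad-into-pack direction: the total padding relative to the unpadded source (the pad op's high padding plus whatever the pack would add implicitly) must stay under one tile. A standalone sketch of that arithmetic, using the shapes from the updated fold-into-pack-and-unpack.mlir tests further down; names are illustrative, not the pattern's actual code:

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Replays the artificial-padding bound for pad-into-pack: a 9-element dim
// packed with an inner tile of 8, as in the tests below.
int main() {
  const int64_t srcDimSize = 9; // tensor<9x16xf32>, dim 0, before tensor.pad
  const int64_t tileSize = 8;   // inner_tiles = [8, 32], dim 0 tile
  // Packing into 2 outer tiles pads 9 up to 16 (7 < 8: foldable); packing
  // into 3 outer tiles pads 9 up to 24 (15 >= 8: artificial padding).
  for (int64_t outerSize : {2, 3}) {
    const int64_t paddingSize = outerSize * tileSize - srcDimSize;
    std::printf("outer %lld -> paddingSize %2lld -> %s\n",
                static_cast<long long>(outerSize),
                static_cast<long long>(paddingSize),
                paddingSize < tileSize ? "fold pad into pack"
                                       : "reject (artificial padding)");
  }
  return 0;
}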

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 47 additions & 2 deletions

@@ -1889,7 +1889,7 @@ func.func @fold_cast_unpack_dynamic_tile_size(
 // linalg.unpack + tensor.extract_slice
 //===----------------------------------------------------------------------===//
 
-func.func @fold_extract_slice_into_unpack(
+func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(
     %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
   ) -> tensor<28x28x10xf32> {
   %unpack = linalg.unpack %src

@@ -1901,7 +1901,7 @@ func.func @fold_extract_slice_into_unpack(
       [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
   return %extracted_slice : tensor<28x28x10xf32>
 }
-// CHECK-LABEL: func @fold_extract_slice_into_unpack
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[SIZE:[a-zA-Z0-9]+]]

@@ -1913,6 +1913,51 @@ func.func @fold_extract_slice_into_unpack(
 
 // -----
 
+// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2.
+
+func.func @fold_extract_slice_into_unpack_slicing_dim_1(
+    %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
+  ) -> tensor<28x17x15xf32> {
+  %unpack = linalg.unpack %src
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32>
+  return %extracted_slice : tensor<28x17x15xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[SIZE:[a-zA-Z0-9]+]]
+// CHECK:         %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+// CHECK-SAME:      [0, 0, 0] [28, 17, 15] [1, 1, 1]
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME:      into %[[DEST_SLICE]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
+// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2.
+
+func.func @no_fold_extract_slice_into_unpack_artificial_padding(
+    %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
+  ) -> tensor<28x16x15xf32> {
+  %unpack = linalg.unpack %src
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
+  return %extracted_slice : tensor<28x16x15xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding
+// CHECK:         linalg.unpack
+// CHECK:         tensor.extract_slice
+
+// -----
+
 func.func @no_fold_extract_slice_into_unpack_dynamic(
     %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
   ) -> tensor<28x28x?xf32> {

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 13 additions & 17 deletions

@@ -69,39 +69,37 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :
 
 // -----
 
-func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
+func.func @fold_pad_pack(%src: tensor<9x16xf32>) -> tensor<2x1x8x32xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %padded = tensor.pad %src low[0, 0] high[7, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
-  } : tensor<16649x16xf32> to tensor<16656x16xf32>
-  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  } : tensor<9x16xf32> to tensor<16x16xf32>
+  %empty = tensor.empty() : tensor<2x1x8x32xf32>
   %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
-      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
-  return %pack : tensor<2082x1x8x32xf32>
+      : tensor<16x16xf32> -> tensor<2x1x8x32xf32>
+  return %pack : tensor<2x1x8x32xf32>
 }
-// CHECK-LABEL: func.func @pad_pack
+// CHECK-LABEL: func.func @fold_pad_pack
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK:         %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
+// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2x1x8x32xf32>
 // CHECK:         %[[PACK:.+]] = linalg.pack %[[SRC]]
 // CHECK-SAME:      padding_value(%[[PAD_VAL]] : f32)
 // CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
 
 // -----
 
-func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
+func.func @nofold_pad_pack_artificial_padding(%src: tensor<9x16xf32>) -> tensor<3x1x8x32xf32> {
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[8, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
-  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  } : tensor<9x16xf32> to tensor<17x16xf32>
+  %empty = tensor.empty() : tensor<3x1x8x32xf32>
   %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
-      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
-  return %pack : tensor<2082x1x8x32xf32>
+      : tensor<17x16xf32> -> tensor<3x1x8x32xf32>
+  return %pack : tensor<3x1x8x32xf32>
 }
 // CHECK-LABEL: func.func @nofold_pad_pack_artificial_padding(
 // CHECK:         tensor.pad

@@ -110,7 +108,6 @@ func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
 // -----
 
 func.func @nofold_pad_pack_with_nofold_attribute(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
   %padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
   ^bb0(%arg0: index, %arg1: index):

@@ -128,7 +125,6 @@ func.func @nofold_pad_pack_with_nofold_attribute(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
 // -----
 
 func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
   %cst0 = arith.constant 0.000000e+00 : f32
   %cst1 = arith.constant 1.000000e+00 : f32
   %padded = tensor.pad %src low[0, 0] high[15, 0] {
