Skip to content

Commit ad3cda7

Browse files
authored
[mlir][tensor] Enhance SimplifyUnPackToCollapseShape for unit dim cases. (#79262)
1 parent ca0e241 commit ad3cda7

File tree

2 files changed

+92
-28
lines changed

2 files changed

+92
-28
lines changed

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,27 @@ static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
2828
shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
2929
}
3030

31+
/// Returns success() if there is only 1 dimension size in non-packed domain
32+
/// being greater than 1 and packing only happens on the dimension.
33+
/// Note: this method should only be used by pack/unpack to reshape conversion.
34+
/// It assumes that non-unit inner tile size must be used by the non-unit
35+
/// dimension.
36+
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
37+
ArrayRef<int64_t> srcShape,
38+
ArrayRef<int64_t> innerPackTileSize) {
39+
if (getNumGtOneDims(srcShape) > 1) {
40+
return rewriter.notifyMatchFailure(
41+
op, "expects non-packed domain to have at most one non-unit dims");
42+
}
43+
// Non-unit inner tile size must be used by the non-unit dimension. If not, it
44+
// will faill on getting reassociation maps.
45+
if (getNumGtOneDims(innerPackTileSize) > 1) {
46+
return rewriter.notifyMatchFailure(
47+
op, "expects at most one non-unit inner tiles");
48+
}
49+
return success();
50+
}
51+
3152
/// Packing one-dimensional tensor can be expressed as an expand shape op.
3253
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
3354
using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -59,40 +80,18 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
5980
return success();
6081
}
6182

62-
/// Returns success() if there is only 1 dimension size in source being
63-
/// greater than 1 and packing only happens on the dimension. It assumes that
64-
/// the pack op does not have padding value.
65-
LogicalResult isPack1DSrc(RewriterBase &rewriter, PackOp packOp) const {
66-
assert(!packOp.getPaddingValue() &&
67-
"expect the op does not have padding value.");
68-
ArrayRef<int64_t> srcShape = packOp.getSourceType().getShape();
69-
if (getNumGtOneDims(srcShape) > 1) {
70-
return rewriter.notifyMatchFailure(
71-
packOp, "expects source to have at most one non-unit dims");
72-
}
73-
74-
// The pack op does not have padding value. Non-unit inner tile size must be
75-
// be used by the non-unit dimension.
76-
SmallVector<int64_t> innerTiles = packOp.getStaticTiles();
77-
if (getNumGtOneDims(innerTiles) > 1) {
78-
return rewriter.notifyMatchFailure(
79-
packOp, "expects at most one non-unit inner tiles");
80-
}
81-
82-
return success();
83-
}
84-
8583
LogicalResult matchAndRewrite(PackOp packOp,
8684
PatternRewriter &rewriter) const override {
8785
if (packOp.getPaddingValue())
8886
return rewriter.notifyMatchFailure(packOp, "expects no padding value");
8987

88+
RankedTensorType sourceType = packOp.getSourceType();
9089
if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
91-
failed(isPack1DSrc(rewriter, packOp))) {
90+
failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
91+
packOp.getStaticTiles()))) {
9292
return failure();
9393
}
9494

95-
RankedTensorType sourceType = packOp.getSourceType();
9695
RankedTensorType destType = packOp.getDestType();
9796
auto reassociation =
9897
getReassociationIndicesForReshape(sourceType, destType);
@@ -117,8 +116,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
117116
operand, reassociation);
118117
}
119118

120-
LogicalResult matchAndRewrite(UnPackOp unpackOp,
121-
PatternRewriter &rewriter) const override {
119+
/// Returns success() if it is unpacking on the innermost dimension.
120+
LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
121+
UnPackOp unpackOp) const {
122122
auto outerDimsPerm = unpackOp.getOuterDimsPerm();
123123
if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
124124
return rewriter.notifyMatchFailure(
@@ -134,9 +134,22 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
134134
ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
135135
if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
136136
return rewriter.notifyMatchFailure(
137-
unpackOp, "expects unpacking at the innermost dimension");
137+
unpackOp, "expects unpacking on the innermost dimension");
138138
}
139139

140+
return success();
141+
}
142+
143+
LogicalResult matchAndRewrite(UnPackOp unpackOp,
144+
PatternRewriter &rewriter) const override {
145+
RankedTensorType destType = unpackOp.getDestType();
146+
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
147+
failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
148+
unpackOp.getStaticTiles()))) {
149+
return failure();
150+
}
151+
152+
RankedTensorType sourceType = unpackOp.getSourceType();
140153
auto reassociation =
141154
getReassociationIndicesForReshape(sourceType, destType);
142155
if (!reassociation)

mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,54 @@ func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor
215215
%0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
216216
return %0 : tensor<256x5xf32>
217217
}
218+
219+
// -----

// All inner tiles are unit-sized, so the unpack is a pure reshape: the
// tensor<1x32x1x1> source collapses directly into the tensor<1x32> dest.
// CHECK-LABEL: func.func @unpack_1x32x1x1_to_1x32
// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
// CHECK:         return %[[COLLAPSED]]
func.func @unpack_1x32x1x1_to_1x32(%arg0 : tensor<1x32x1x1xf32>) -> tensor<1x32xf32> {
  %empty = tensor.empty() : tensor<1x32xf32>
  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %empty
    : tensor<1x32x1x1xf32> -> tensor<1x32xf32>
  return %unpack : tensor<1x32xf32>
}
231+
232+
// -----

// The dest has a single non-unit dim (32) and a single non-unit inner tile
// (16) unpacks into it, so the op simplifies to a collapse_shape.
// CHECK-LABEL: func.func @unpack_1x2x1x16_to_1x32
// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
// CHECK:         return %[[COLLAPSED]]
func.func @unpack_1x2x1x16_to_1x32(%arg0 : tensor<1x2x1x16xf32>) -> tensor<1x32xf32> {
  %empty = tensor.empty() : tensor<1x32xf32>
  %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [1, 16] into %empty
    : tensor<1x2x1x16xf32> -> tensor<1x32xf32>
  return %unpack : tensor<1x32xf32>
}
244+
245+
// -----

// Unit-dim case with a non-identity outer_dims_perm: the dest's only non-unit
// dim (32) receives the only non-unit tile (2), so the unpack still collapses.
// CHECK-LABEL: func.func @unpack_16x1x2x1_to_32x1
// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK:         return %[[COLLAPSED]]
func.func @unpack_16x1x2x1_to_32x1(%arg0 : tensor<1x16x2x1xf32>) -> tensor<32x1xf32> {
  %empty = tensor.empty() : tensor<32x1xf32>
  %unpack = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 1] into %empty
    : tensor<1x16x2x1xf32> -> tensor<32x1xf32>
  return %unpack : tensor<32x1xf32>
}
257+
258+
// -----

// Negative test: inner_dims_pos = [1, 0] places the non-unit tile (2) on dim 0
// while the source layout does not match a plain collapse, so the unpack must
// NOT be simplified and stays as tensor.unpack.
// CHECK-LABEL: func.func @unpack_16x1x1x2_to_32x1
// CHECK-NOT:     tensor.collapse_shape
// CHECK:         tensor.unpack
func.func @unpack_16x1x1x2_to_32x1(%arg0 : tensor<16x1x1x2xf32>) -> tensor<32x1xf32> {
  %empty = tensor.empty() : tensor<32x1xf32>
  %unpack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [1, 2] into %empty
    : tensor<16x1x1x2xf32> -> tensor<32x1xf32>
  return %unpack : tensor<32x1xf32>
}

0 commit comments

Comments
 (0)