Skip to content

Commit f331d44

Browse files
committed
Address comments from Max
1 parent f6cff33 commit f331d44

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1651,7 +1651,7 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
16511651
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
16521652
/// tensor::InsertSliceOp ops.
16531653
///
1654-
/// Requires that all the tile outer dims of the input linalg::PackOp are 1.
1654+
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
16551655
/// Note that this constraint means to effectively one tile is packed.
16561656
///
16571657
/// In addition, assumes that the un-tiled outer dims are not permuted.

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,13 +1140,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11401140
packOp, "not all outer dimensions of the result are 1s");
11411141
}
11421142

1143+
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
1144+
auto outerDimsPerm = packOp.getOuterDimsPerm();
1145+
1146+
// Verify that there are no non-unit un-tiled outer dims that are permuted.
1147+
// Supporting such cases will require refining the logic to generate the
1148+
// Transpose Op.
1149+
if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
1150+
static int prev = 0;
1151+
// Tiled dims are not relevant here.
1152+
if (llvm::is_contained(innerDimsPos, dim))
1153+
return true;
1154+
// Was this dim permuted? Note, permuting unit dims is fine.
1155+
if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
1156+
packOp.getType().getShape()[dim] != 1))
1157+
return false;
1158+
1159+
prev = dim;
1160+
return true;
1161+
})) {
1162+
return rewriter.notifyMatchFailure(
1163+
packOp, "At least one non-unit and un-tiled outer dim is permuted, "
1164+
"this is not supported ATM!");
1165+
}
1166+
11431167
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
11441168
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
11451169
Location loc = packOp.getLoc();
11461170

11471171
int64_t srcRank = packOp.getSourceRank();
11481172
int64_t destRank = packOp.getDestRank();
1149-
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
11501173

11511174
// 1. Get the input that is going to be packed. If the input requires padding,
11521175
// add a padding operation and return that as the input.
@@ -1185,8 +1208,10 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11851208
ShapedType inputTy = cast<ShapedType>(input.getType());
11861209
SmallVector<OpFoldResult> shapeForEmptyOp;
11871210
for (int64_t i = 0; i < srcRank; i++) {
1188-
if (llvm::is_contained(innerDimsPos, i))
1211+
if (llvm::is_contained(innerDimsPos, i)) {
1212+
// The tiled dims are appended after this loop.
11891213
continue;
1214+
}
11901215
if (inputTy.isStaticDim(i))
11911216
shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
11921217
else
@@ -1231,15 +1256,15 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12311256
}
12321257

12331258
for (auto tileSize : packOp.getMixedTiles()) {
1234-
auto [tileSizeStatic, tileSizeOfr] =
1259+
auto [_, tileSizeOfr] =
12351260
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
12361261
writeSizes.push_back(tileSizeOfr);
12371262
}
12381263

12391264
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
12401265
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
12411266

1242-
// TODO: A constructor that doesn't require strised nor offsets.
1267+
// TODO: A constructor that doesn't require strides nor offsets.
12431268
auto insert = tensor::InsertSliceOp::create(
12441269
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
12451270
writeOffsets, writeSizes, writeStrides);

mlir/test/Dialect/Linalg/decompose-pack.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t
176176

177177
// -----
178178

179+
// Note - un-tiled outer dims are permueted. However, these are unit dims, which is supported.
180+
179181
func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x1x5x1xf32>, %output: tensor<1x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x1x1x1x2x?xf32> {
180182
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x1x5x1xf32> -> tensor<1x1x1x1x2x?xf32>
181183
return %0 : tensor<1x1x1x1x2x?xf32>
@@ -201,6 +203,28 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x
201203

202204
// -----
203205

206+
// Similar as the example above, but one of the un-tiled outer dims that are permuted is non-unit: (7,1) -> (1, 7)
207+
208+
func.func @negative_not_all_dims_tiled_outer_dim_0_permuted(%input: tensor<7x1x5x1xf32>, %output: tensor<1x7x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x7x1x1x2x?xf32> {
209+
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<7x1x5x1xf32> -> tensor<1x7x1x1x2x?xf32>
210+
return %0 : tensor<1x7x1x1x2x?xf32>
211+
}
212+
// CHECK-LABEL: func.func @negative_not_all_dims_tiled_outer_dim_0_permuted
213+
// CHECK: linalg.pack
214+
215+
// -----
216+
217+
// Similar as the example above, but one of the un-tiled outer dims that are permuted is non-unit: (1, 7) -> (7, 1).
218+
219+
func.func @negative_not_all_dims_tiled_outer_dim_1_permuted(%input: tensor<1x7x5x1xf32>, %output: tensor<7x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<7x1x1x1x2x?xf32> {
220+
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x7x5x1xf32> -> tensor<7x1x1x1x2x?xf32>
221+
return %0 : tensor<7x1x1x1x2x?xf32>
222+
}
223+
// CHECK-LABEL: func.func @negative_not_all_dims_tiled_outer_dim_1_permuted
224+
// CHECK: linalg.pack
225+
226+
// -----
227+
204228
func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
205229
%0 = linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x8xf32> -> tensor<1x1x32x8xf32>
206230
return %0 : tensor<1x1x32x8xf32>

0 commit comments

Comments
 (0)