Skip to content

Commit a5c1a2d

Browse files
banach-spaceakadutta
authored and committed
[mlir][linalg] Extend DecomposeOuterUnitDimsPackOpPattern (linalg.pack) (llvm#162666)
Similarly to llvm#152960, this PR fixes `getTiledOuterDims` for `linalg.pack` by ensuring that the `outer_dims_perm` attributeis properly taken into account. This enables the main change in this PR: relaxing the constraints in * `DecomposeOuterUnitDimsPackOpPattern`. Specifically, the pattern is extended to allow non-unit untiled outer dimensions. For example: ```mlir func.func @example( %src: tensor<2x32x16x8xf32>, %dest: tensor<2x1x16x8x32xf32>) -> tensor<2x1x16x8x32xf32> { %pack = linalg.pack %src inner_dims_pos = [1] inner_tiles = [32] into %dest : tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32> return %pack : tensor<2x1x16x8x32xf32> } ``` decomposes as: ```mlir func.func @example( %src: tensor<2x32x16x8xf32>, %dest: tensor<2x1x16x8x32xf32>) -> tensor<2x1x16x8x32xf32> { %0 = tensor.empty() : tensor<2x16x8x32xf32> %transposed = linalg.transpose ins(%src : tensor<2x32x16x8xf32>) outs(%init : tensor<2x16x8x32xf32>) permutation = [0, 2, 3, 1] %inserted_slice = tensor.insert_slice %transposed into %dest[0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1] : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32> return %inserted_slice : tensor<2x1x16x8x32xf32> } ``` Importantly, this change makes `DecomposeOuterUnitDimsPackOpPattern` (the decomposition pattern for `linalg.pack`) consistent with the corresponding pattern for `linalg.unpack`: * `DecomposeOuterUnitDimsUnPackOpPattern`. One notable assumption remains: untiled outer dimensions are not permuted. This was already the case but is now explicitly documented. Co-authored by: Max Bartel <[email protected]>
1 parent 694be6a commit a5c1a2d

File tree

5 files changed

+150
-25
lines changed

5 files changed

+150
-25
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,8 +1650,12 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
16501650
/// Rewrites a linalg::PackOp into a sequence of:
16511651
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
16521652
/// tensor::InsertSliceOp ops.
1653+
/// (InsertSliceOp is rank-expanding).
16531654
///
1654-
/// Requires that all the outer dims of the input linalg::PackOp are 1.
1655+
/// Requires that all the tiled-outer-dims of the input linalg::PackOp are 1.
1656+
/// Note that this constraint means that effectively exactly one tile is packed.
1657+
///
1658+
/// In addition, assumes that the un-tiled-outer-dims are not permuted.
16551659
///
16561660
/// Before:
16571661
/// ```
@@ -1687,10 +1691,13 @@ struct DecomposeOuterUnitDimsPackOpPattern
16871691
PatternRewriter &rewriter) const override;
16881692
};
16891693

1690-
/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
1694+
/// Rewrites a linalg::UnPackOp into a sequence of:
16911695
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
1696+
/// (ExtractSliceOp is rank-reducing).
16921697
///
1693-
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
1698+
/// Requires that all the tiled-outer-dims of the input linalg::UnPackOp are 1.
1699+
/// Note that this constraint means that effectively exactly one tile is
1700+
/// unpacked.
16941701
///
16951702
/// Before:
16961703
/// ```

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {
52725272

52735273
SmallVector<int64_t> PackOp::getTiledOuterDims() {
52745274
auto innerDimsPos = getInnerDimsPos();
5275-
auto packedShape = getDestType().getShape();
5275+
SmallVector<int64_t> outerDims(getAllOuterDims());
52765276
SmallVector<int64_t> res;
52775277

5278+
// Recover the original order of the outer dims.
5279+
SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5280+
invertPermutationVector(outerDimPermInv);
5281+
if (!outerDimPermInv.empty())
5282+
applyPermutationToVector(outerDims, outerDimPermInv);
5283+
5284+
// Collect the outer dims corresponding to the tiled inner dims.
52785285
for (auto index : innerDimsPos)
5279-
res.push_back(packedShape[index]);
5286+
res.push_back(outerDims[index]);
52805287

52815288
return res;
52825289
}

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
11341134

11351135
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11361136
linalg::PackOp packOp, PatternRewriter &rewriter) const {
1137-
// TODO: support the case that outer dimensions are not all 1s. A
1138-
// tensor.expand_shape will be generated in this case.
1139-
if (llvm::any_of(packOp.getAllOuterDims(),
1137+
if (llvm::any_of(packOp.getTiledOuterDims(),
11401138
[](int64_t dim) { return dim != 1; })) {
11411139
return rewriter.notifyMatchFailure(
11421140
packOp, "not all outer dimensions of the result are 1s");
11431141
}
11441142

1143+
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
1144+
auto outerDimsPerm = packOp.getOuterDimsPerm();
1145+
1146+
// Verify that there are no:
1147+
// * non-unit + un-tiled-outer-dims,
1148+
// that are permuted. Supporting such cases would require refining the logic
1149+
// that generates the Transpose Op.
1150+
if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
1151+
static int prev = 0;
1152+
// Skip tiled dims - these can be permuted.
1153+
if (llvm::is_contained(innerDimsPos, dim))
1154+
return true;
1155+
1156+
// Check whether this dim has been permuted. Permuting unit dims is fine
1157+
// as that's effectively a no-op.
1158+
if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
1159+
packOp.getType().getShape()[dim] != 1))
1160+
return false;
1161+
1162+
prev = dim;
1163+
return true;
1164+
})) {
1165+
return rewriter.notifyMatchFailure(
1166+
packOp, "At least one non-unit and un-tiled outer dim is permuted, "
1167+
"this is not supported ATM!");
1168+
}
1169+
11451170
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
11461171
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
11471172
Location loc = packOp.getLoc();
11481173

11491174
int64_t srcRank = packOp.getSourceRank();
11501175
int64_t destRank = packOp.getDestRank();
1151-
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
1152-
int64_t numberOfTiles = innerDimsPos.size();
11531176

11541177
// 1. Get the input that is going to be packed. If the input requires padding,
11551178
// add a padding operation and return that as the input.
@@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11601183
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
11611184
// outs(%init)
11621185
// Assumptions made:
1163-
// - All outer dims are 1 - the corresponding transposition order doesn't
1164-
// matter, but requires all dim indices to be present.
1186+
// - All tiled outer dims are 1 - the corresponding transposition order
1187+
// doesn't matter, but requires all dim indices to be present.
1188+
// - Un-tiled outer dims remain un-permuted.
11651189

1166-
// 2.1 Get the permutation for linalg.transpose
1190+
// 2.1 Get the permutation for linalg.transpose:
1191+
// [ untiled-dims, inner-dims-pos ]
1192+
// Note, this logic assumes that the untiled dims are not permuted.
11671193
SmallVector<int64_t> srcPermForTranspose;
11681194
for (int64_t i = 0; i < srcRank; i++) {
11691195
// We assume the `k` dimensions of the inner dim position, where `k` is the
@@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11791205
}
11801206
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
11811207

1182-
// 2.2 Create the init tensor for linalg.transpose with the correct shape
1183-
SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
1184-
oneIdxAttr);
1208+
// 2.2 Create the init tensor for linalg.transpose with the correct shape:
1209+
// [ untiled-dims, tiled-dims ]
1210+
ShapedType inputTy = cast<ShapedType>(input.getType());
1211+
SmallVector<OpFoldResult> shapeForEmptyOp;
1212+
for (int64_t i = 0; i < srcRank; i++) {
1213+
if (llvm::is_contained(innerDimsPos, i)) {
1214+
// The tiled dims are appended after this loop.
1215+
continue;
1216+
}
1217+
if (inputTy.isStaticDim(i))
1218+
shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
1219+
else
1220+
shapeForEmptyOp.emplace_back(
1221+
tensor::DimOp::create(rewriter, loc, input, i).getResult());
1222+
}
11851223
shapeForEmptyOp.append(packOp.getMixedTiles());
11861224

11871225
// getMixedTiles() may contain Values pointing to constant ops, not the
@@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12041242
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
12051243
srcPermForTranspose);
12061244

1207-
// 3. Insert the inner tile to the destination:
1245+
// 3. Insert the inner tile into the destination tensor:
12081246
// %inserted_tile = tensor.insert_slice(%transposed_tile)
1209-
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1210-
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1211-
// Outer dims are all 1s!
1212-
SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
1213-
SmallVector<int64_t> writeShape;
1247+
1248+
// Compute the sizes attribute:
1249+
// [ outer-dims, tile-sizes ]
1250+
// Note that the output from the transpose Op excludes the tiled outer dims.
1251+
// However, given the assumption that:
1252+
// * all tiled outer dims == 1,
1253+
// we can just use a rank-expanding tensor.insert_slice.
1254+
SmallVector<OpFoldResult> writeSizes;
1255+
for (auto size : packOp.getAllOuterDims()) {
1256+
writeSizes.push_back(rewriter.getIndexAttr(size));
1257+
}
12141258

12151259
for (auto tileSize : packOp.getMixedTiles()) {
1216-
auto [tileSizeStatic, tileSizeOfr] =
1260+
auto [_, tileSizeOfr] =
12171261
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
12181262
writeSizes.push_back(tileSizeOfr);
1219-
writeShape.push_back(tileSizeStatic);
12201263
}
12211264

1222-
// 4. Replace tensor.packOp with tensor.insert_slice created above
1265+
// TODO: Add a constructor for tensor.insert_slice that doesn't require
1266+
// strides nor offsets.
1267+
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1268+
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1269+
12231270
auto insert = tensor::InsertSliceOp::create(
12241271
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
12251272
writeOffsets, writeSizes, writeStrides);
1273+
1274+
// 4. Replace tensor.packOp with tensor.insert_slice created above
12261275
rewriter.replaceOp(packOp, insert.getResult());
12271276

12281277
return success();

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
23102310
sourceTensorType.getEncoding());
23112311
}
23122312

2313+
// TODO: This uses neither offsets nor strides!
23132314
RankedTensorType ExtractSliceOp::inferResultType(
23142315
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
23152316
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {

mlir/test/Dialect/Linalg/decompose-pack.mlir

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,25 @@ func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi
3131

3232
// -----
3333

34+
func.func @NCHW_to_NCHWc(%src: tensor<2x32x16x8xf32>, %dest: tensor<2x1x16x8x32xf32>) -> tensor<2x1x16x8x32xf32> {
35+
%pack = linalg.pack %src
36+
inner_dims_pos = [1]
37+
inner_tiles = [32] into %dest
38+
: tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32>
39+
return %pack : tensor<2x1x16x8x32xf32>
40+
}
41+
// CHECK-LABEL: func.func @NCHW_to_NCHWc(
42+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
43+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
44+
// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x16x8x32xf32>
45+
// CHECK: %[[TR:.*]] = linalg.transpose ins(%[[SRC]] : tensor<2x32x16x8xf32>) outs(%[[INIT]] : tensor<2x16x8x32xf32>) permutation = [0, 2, 3, 1]
46+
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]]
47+
// CHECK-SAME: [0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
48+
// CHECK-SAME: : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32>
49+
// CHECK: return %[[RES]] : tensor<2x1x16x8x32xf32>
50+
51+
// -----
52+
3453
func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> {
3554
%0 = linalg.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
3655
return %0 : tensor<1x1x8x2xf32>
@@ -157,6 +176,8 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t
157176

158177
// -----
159178

179+
// Note - un-tiled outer dims are permuted. However, these are unit dims, which is supported.
180+
160181
func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x1x5x1xf32>, %output: tensor<1x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x1x1x1x2x?xf32> {
161182
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x1x5x1xf32> -> tensor<1x1x1x1x2x?xf32>
162183
return %0 : tensor<1x1x1x1x2x?xf32>
@@ -182,6 +203,28 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x
182203

183204
// -----
184205

206+
// Similar to the example above, but one of the permuted un-tiled outer dims is non-unit: (7, 1) -> (1, 7).
207+
208+
func.func @negative_not_all_dims_tiled_outer_dim_0_permuted(%input: tensor<7x1x5x1xf32>, %output: tensor<1x7x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x7x1x1x2x?xf32> {
209+
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<7x1x5x1xf32> -> tensor<1x7x1x1x2x?xf32>
210+
return %0 : tensor<1x7x1x1x2x?xf32>
211+
}
212+
// CHECK-LABEL: func.func @negative_not_all_dims_tiled_outer_dim_0_permuted
213+
// CHECK: linalg.pack
214+
215+
// -----
216+
217+
// Similar to the example above, but one of the permuted un-tiled outer dims is non-unit: (1, 7) -> (7, 1).
218+
219+
func.func @negative_not_all_dims_tiled_outer_dim_1_permuted(%input: tensor<1x7x5x1xf32>, %output: tensor<7x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<7x1x1x1x2x?xf32> {
220+
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x7x5x1xf32> -> tensor<7x1x1x1x2x?xf32>
221+
return %0 : tensor<7x1x1x1x2x?xf32>
222+
}
223+
// CHECK-LABEL: func.func @negative_not_all_dims_tiled_outer_dim_1_permuted
224+
// CHECK: linalg.pack
225+
226+
// -----
227+
185228
func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
186229
%0 = linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x8xf32> -> tensor<1x1x32x8xf32>
187230
return %0 : tensor<1x1x32x8xf32>
@@ -295,3 +338,21 @@ func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x
295338
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
296339
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
297340
// CHECK: return %[[INSERT]]
341+
342+
// -----
343+
344+
/// Note "126", which is a non-unit tiled-outer-dim. This is not supported.
345+
346+
func.func @negative_non_unit_tiled_outer_dim(%dest: tensor<1x126x1x1x8xf32>, %src: tensor<1x1x1x1001xf32>, %pad: f32) -> tensor<1x126x1x1x8xf32> {
347+
%pack = linalg.pack %src
348+
padding_value(%pad : f32)
349+
outer_dims_perm = [0, 3, 2, 1]
350+
inner_dims_pos = [3]
351+
inner_tiles = [8]
352+
into %dest
353+
: tensor<1x1x1x1001xf32> -> tensor<1x126x1x1x8xf32>
354+
355+
return %pack : tensor<1x126x1x1x8xf32>
356+
}
357+
// CHECK-LABEL: @negative_non_unit_tiled_outer_dim(
358+
// CHECK: linalg.pack

0 commit comments

Comments
 (0)