Skip to content

Commit fb2842e

Browse files
authored
[mlir][linalg] fix DecomposeOuterUnitDimsPackOpPattern (#21)
Given the following example: ``` func.func @pack_with_unit_outer_dims_and_non_adjacent_inner(%arg0: tensor<3x1x4xf32>, %arg1: tensor<1x1x1x4x3xf32>) -> tensor<1x1x1x4x3xf32> { %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 3] into %arg1 : tensor<3x1x4xf32> -> tensor<1x1x1x4x3xf32> return %pack : tensor<1x1x1x4x3xf32> } ``` We would up until now creating an invalid transpose. That is because we would use the `getDimAndTileMapping()` function of the packOp which tranposes the tile dimensions to match based on the given `inner_dims_pos` value. Here in the above example we have `inner_dims_pos` of `[2, 0]` meaning from the source tensors the indices 2 and 0 must be the `inner_tiles`. This property is not required for calculating the tile sizes as the destination tensor shape will be simply `[1x1x1x4x3]`. The inner dimensions positions are only required for calculating the tranpose. With this we can simplify the pattern.
1 parent c357983 commit fb2842e

File tree

3 files changed

+51
-19
lines changed

3 files changed

+51
-19
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,8 +1160,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11601160
Location loc = packOp.getLoc();
11611161

11621162
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1163-
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1164-
packOp.getDimAndTileMapping();
11651163
int64_t srcRank = packOp.getSourceRank();
11661164
int64_t destRank = packOp.getDestRank();
11671165
int64_t numTiles = destRank - srcRank;
@@ -1174,18 +1172,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11741172
packOp, "Attempting to tile non-trailing source dims!");
11751173

11761174
// 1. Extract the inner tile sizes.
1177-
// Where possible, values are replaced with constant attributes (to match the
1178-
// behaviour of `getPackOpSourceOrPaddedSource`).
1175+
// Use the tile sizes as defined in the operation. As all the outer
1176+
// dimensions are 1 and by definition the last `k` dimensions of the
1177+
// destination tensor (packed tensor) will be the tile sizes, we can simply
1178+
// use the tiles for calculating our transpose permutations.
1179+
//
1180+
// Where possible, values are replaced with constant attributes (to match
1181+
// the behaviour of `getPackOpSourceOrPaddedSource`).
11791182
SmallVector<OpFoldResult> tileSizes;
1180-
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1181-
if (dimAndTileMapping.count(i)) {
1182-
// Rather than taking the tile size as is, extact the actual constant
1183-
// value Attribute where possible, e.g.:
1184-
// [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1185-
auto [_, tileSize] =
1186-
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
1187-
tileSizes.push_back(tileSize);
1188-
}
1183+
for (const OpFoldResult &tileSizeDef : packOp.getMixedTiles()) {
1184+
// Rather than taking the tile size as is, extract the actual constant
1185+
// value Attribute where possible, e.g.:
1186+
// [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1187+
auto [_, tileSize] =
1188+
getSimplifiedOfrAndStaticSizePair(tileSizeDef, rewriter);
1189+
tileSizes.push_back(tileSize);
11891190
}
11901191

11911192
// 2. Transpose the input to match the inner tile order:
@@ -1218,9 +1219,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12181219
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
12191220
oneIdxAttr);
12201221
transShapeForEmptyOp.append(tileSizes);
1221-
1222-
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1223-
srcPermForTranspose);
12241222
Value empty = rewriter.create<tensor::EmptyOp>(
12251223
loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
12261224

@@ -1233,8 +1231,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12331231
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
12341232
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
12351233
// Outer dims are all 1s!
1236-
SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
1237-
oneIdxAttr);
1234+
SmallVector<OpFoldResult> writeSizes(destRank - numTiles, oneIdxAttr);
12381235
SmallVector<int64_t> writeShape;
12391236

12401237
for (auto tileSize : packOp.getMixedTiles()) {

mlir/test/Dialect/Linalg/decompose-pack.mlir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,4 +247,22 @@ func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %a
247247
// CHECK-SAME: permutation = [1, 2, 0]
248248
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
249249
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
250-
// CHECK: return %[[INSERT]]
250+
// CHECK: return %[[INSERT]]
251+
252+
// -----
253+
254+
func.func @pack_with_unit_outer_dims_and_non_adjacent_inner(%arg0: tensor<4x1x3xf32>, %arg1: tensor<1x1x1x3x4xf32>) -> tensor<1x1x1x3x4xf32> {
255+
%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [3, 4] into %arg1 : tensor<4x1x3xf32> -> tensor<1x1x1x3x4xf32>
256+
return %pack : tensor<1x1x1x3x4xf32>
257+
}
258+
// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_non_adjacent_inner
259+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
260+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
261+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x3x4xf32>
262+
// CHECK: %[[TRANSP:.+]] = linalg.transpose
263+
// CHECK-SAME: ins(%[[SRC]] : tensor<4x1x3xf32>)
264+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x3x4xf32>)
265+
// CHECK-SAME: permutation = [1, 2, 0]
266+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
267+
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 3, 4] [1, 1, 1, 1, 1] : tensor<1x3x4xf32> into tensor<1x1x1x3x4xf32>
268+
// CHECK: return %[[INSERT]]

mlir/test/Dialect/Linalg/decompose-unpack.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,20 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tens
169169
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
170170
// CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
171171
// CHECK: return %[[INSERT]]
172+
173+
// -----
174+
175+
func.func @unpack_with_unit_outer_dims_and_non_adjacent_inner(%arg0: tensor<1x1x1x3x4xf32>, %arg1: tensor<4x1x3xf32>) -> tensor<4x1x3xf32> {
176+
%pack = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [3, 4] into %arg1 : tensor<1x1x1x3x4xf32> -> tensor<4x1x3xf32>
177+
return %pack : tensor<4x1x3xf32>
178+
}
179+
// CHECK-LABEL: func.func @unpack_with_unit_outer_dims_and_non_adjacent_inner
180+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
181+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
182+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 3, 4] [1, 1, 1, 1, 1] : tensor<1x1x1x3x4xf32> to tensor<3x4xf32>
183+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x3xf32>
184+
// CHECK: %[[TRANSP:.+]] = linalg.transpose
185+
// CHECK-SAME: ins(%[[SLICE]] : tensor<3x4xf32>)
186+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x3xf32>) permutation = [1, 0]
187+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [4, 1, 3] [1, 1, 1] : tensor<4x3xf32> into tensor<4x1x3xf32>
188+
// CHECK: return %[[INSERT]]

0 commit comments

Comments
 (0)