Skip to content

Commit f53b624

Browse files
authored
[MLIR][Linalg] Fix empty tensor assumptions for linalg.pack decomposition (#160246)
The original code seemed to assume that the tiling dimensions for the tensor.empty op before applying the transposing were always the last dimensions. However, pack allows you to choose any dimension to tile. The easiest way I found to solve this is to prefill the SmallVector with 1s of size (srcRank - numberOfTiles) and then append the tile sizes. This way I could also get rid of the first loop in the code.
1 parent 81589a3 commit f53b624

File tree

3 files changed

+55
-36
lines changed

3 files changed

+55
-36
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
5757
/// tile factors.
5858
DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
5959

60-
/// Return the tile sizes as OpFoldResult.
60+
// TODO: Return the folded result.
61+
/// Return the tile sizes as OpFoldResult. Will return the Value
62+
/// of the constant Op, not the constant Attribute.
63+
/// E.g., for %size = arith.constant 1 : i32 will return %size,
64+
/// not 1.
6165
SmallVector<OpFoldResult> getMixedTiles();
6266

6367
/// Return the tile sizes as `int64_t`. If a tile size is dynamic

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11461146
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
11471147
Location loc = packOp.getLoc();
11481148

1149-
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1150-
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1151-
packOp.getDimAndTileMapping();
11521149
int64_t srcRank = packOp.getSourceRank();
11531150
int64_t destRank = packOp.getDestRank();
1154-
int64_t numTiles = destRank - srcRank;
1151+
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
1152+
int64_t numberOfTiles = innerDimsPos.size();
11551153

1156-
// 1. Extract the inner tile sizes.
1157-
// Where possible, values are replaced with constant attributes (to match the
1158-
// behaviour of `getPackOpSourceOrPaddedSource`).
1159-
SmallVector<OpFoldResult> tileSizes;
1160-
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1161-
if (dimAndTileMapping.count(i)) {
1162-
// Rather than taking the tile size as is, extact the actual constant
1163-
// value Attribute where possible, e.g.:
1164-
// [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1165-
auto [_, tileSize] =
1166-
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
1167-
tileSizes.push_back(tileSize);
1168-
}
1169-
}
1154+
// 1. Get the input that is going to be packed. If the input requires padding,
1155+
// add a padding operation and return that as the input.
1156+
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
11701157

11711158
// 2. Transpose the input to match the inner tile order:
11721159
// %init = tensor.empty()
11731160
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
11741161
// outs(%init)
11751162
// Assumptions made:
1176-
// 1. All outer dims are 1 - the corresponding transposition order doesn't
1163+
// - All outer dims are 1 - the corresponding transposition order doesn't
11771164
// matter, but requires all dim indices to be present.
1165+
1166+
// 2.1 Get the permutation for linalg.transpose
11781167
SmallVector<int64_t> srcPermForTranspose;
1179-
ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
11801168
for (int64_t i = 0; i < srcRank; i++) {
11811169
// We assume the `k` dimensions of the inner dim position, where `k` is the
11821170
// rank of the inner tiling, correspond to the last `k` indices of the
@@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11851173
// rank of the source tensor. For example if we have a source tensor with
11861174
// indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
11871175
// indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
1188-
if (llvm::is_contained(innerDimPos, i))
1176+
if (llvm::is_contained(innerDimsPos, i))
11891177
continue;
11901178
srcPermForTranspose.push_back(i);
11911179
}
1192-
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
1180+
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
1181+
1182+
// 2.2 Create the init tensor for linalg.transpose with the correct shape
1183+
SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
1184+
oneIdxAttr);
1185+
shapeForEmptyOp.append(packOp.getMixedTiles());
1186+
1187+
// getMixedTiles() may contain Values pointing to constant ops, not the
1188+
// constant attributes. Replace them with a true OpFoldResult.
1189+
llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
1190+
[&](OpFoldResult ofr) {
1191+
if (auto val = llvm::dyn_cast<Value>(ofr))
1192+
return getAsOpFoldResult(val);
1193+
return ofr;
1194+
});
11931195

11941196
LDBG() << "Pack permutation: " << packOp;
11951197
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
1198+
LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
11961199

1197-
// 2.1 Create tensor.empty (init value for TransposeOp)
1198-
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
1199-
oneIdxAttr);
1200-
transShapeForEmptyOp.append(tileSizes);
1201-
1202-
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1203-
srcPermForTranspose);
1204-
Value empty =
1205-
tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
1206-
packOp.getSourceType().getElementType());
1200+
Value empty = tensor::EmptyOp::create(
1201+
rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
12071202

1208-
// 2.2 Create linalg.transpose
1203+
// 2.3 Create linalg.transpose
12091204
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
12101205
srcPermForTranspose);
12111206

@@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12141209
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
12151210
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
12161211
// Outer dims are all 1s!
1217-
SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
1218-
oneIdxAttr);
1212+
SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
12191213
SmallVector<int64_t> writeShape;
12201214

12211215
for (auto tileSize : packOp.getMixedTiles()) {

mlir/test/Dialect/Linalg/decompose-pack.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,24 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
274274
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
275275
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
276276
// CHECK: return %[[INSERT]]
277+
278+
// -----
279+
280+
// The following example shows a pack operation where the inner dims
281+
// positions are non-adjacent and non-permuted.
282+
func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
283+
%pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
284+
return %pack : tensor<1x1x1x1x8x1xf32>
285+
}
286+
287+
// CHECK-LABEL: func.func @pack_with_non_adjacent_and_non_permuted_inner_dims
288+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
289+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
290+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32>
291+
// CHECK: %[[TRANSP:.+]] = linalg.transpose
292+
// CHECK-SAME: ins(%[[SRC]] : tensor<8x1x1x1xf32>)
293+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x1xf32>)
294+
// CHECK-SAME: permutation = [1, 2, 0, 3]
295+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
296+
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
297+
// CHECK: return %[[INSERT]]

0 commit comments

Comments
 (0)