Skip to content

Commit eae9f99

Browse files
committed
(linalg.pack): simplify outer dims patterns after review
1 parent 9f031bc commit eae9f99

File tree

2 files changed

+30
-32
lines changed

2 files changed

+30
-32
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
5757
/// tile factors.
5858
DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
5959

60-
/// Return the tile sizes as OpFoldResult.
60+
/// Return the tile sizes as OpFoldResult. A tile size that is defined by a
61+
/// constant op is returned as that op's Value, not as the constant Attribute.
6162
SmallVector<OpFoldResult> getMixedTiles();
6263

6364
/// Return the tile sizes as `int64_t`. If a tile size is dynamic

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,38 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11461146
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
11471147
Location loc = packOp.getLoc();
11481148

1149-
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1150-
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1151-
packOp.getDimAndTileMapping();
11521149
int64_t srcRank = packOp.getSourceRank();
11531150
int64_t destRank = packOp.getDestRank();
1151+
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
1152+
int64_t numberOfTiles = innerDimsPos.size();
11541153

1155-
// 1. Extract the inner tile sizes and the shapes for the tensor.empty op
1156-
// before transposing. Where possible, values are replaced with constant
1157-
// attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`).
1158-
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank, oneIdxAttr);
1159-
SmallVector<OpFoldResult> tileSizes;
1160-
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1161-
if (dimAndTileMapping.count(i)) {
1162-
// Rather than taking the tile size as is, extract the actual constant
1163-
// value Attribute where possible, e.g.:
1164-
// [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1165-
auto [_, tileSize] =
1166-
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
1167-
tileSizes.push_back(tileSize);
1168-
transShapeForEmptyOp[i] = tileSize;
1169-
}
1170-
}
1154+
// 1. Get the input that is going to be packed. If the input requires padding,
1155+
// add a padding operation and return that as the input.
1156+
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
11711157

11721158
// 2. Transpose the input to match the inner tile order:
11731159
// %init = tensor.empty()
11741160
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
11751161
// outs(%init)
11761162
// Assumptions made:
1177-
// 1. All outer dims are 1 - the corresponding transposition order doesn't
1163+
// - All outer dims are 1 - the corresponding transposition order doesn't
11781164
// matter, but requires all dim indices to be present.
1165+
1166+
// 2.1 Get the permutation for linalg.transpose
11791167
SmallVector<int64_t> srcPermForTranspose;
1180-
ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
11811168
for (int64_t i = 0; i < srcRank; i++) {
11821169
// We assume the `k` dimensions of the inner dim position, where `k` is the
11831170
// rank of the inner tiling, correspond to the last `k` indices of the
@@ -1186,21 +1173,32 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11861173
// rank of the source tensor. For example if we have a source tensor with
11871174
// indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
11881175
// indices are [1, 2], and the transpose will be [1, 2, 3, 0].
1189-
if (llvm::is_contained(innerDimPos, i))
1176+
if (llvm::is_contained(innerDimsPos, i))
11901177
continue;
11911178
srcPermForTranspose.push_back(i);
11921179
}
1193-
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
1180+
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
1181+
1182+
// 2.2 Create the init tensor for linalg.transpose with the correct shape
1183+
SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
1184+
oneIdxAttr);
1185+
shapeForEmptyOp.append(packOp.getMixedTiles());
1186+
1187+
// getMixedTiles() may contain Values pointing to constant ops, not the
1188+
// constant attributes. Replace such Values with the constant Attribute where possible.
1189+
llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
1190+
[&](OpFoldResult ofr) {
1191+
if (auto val = llvm::dyn_cast<Value>(ofr))
1192+
return getAsOpFoldResult(val);
1193+
return ofr;
1194+
});
11941195

11951196
LDBG() << "Pack permutation: " << packOp;
11961197
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
1198+
LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
11971199

1198-
// 2.2 Transpose the tensor.empty shapes.
1199-
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1200-
srcPermForTranspose);
1201-
Value empty =
1202-
tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
1203-
packOp.getSourceType().getElementType());
1200+
Value empty = tensor::EmptyOp::create(
1201+
rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
12041202

12051203
// 2.3 Create linalg.transpose
12061204
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
@@ -1211,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12111209
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
12121210
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
12131211
// Outer dims are all 1s!
1214-
SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
1215-
oneIdxAttr);
1212+
SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
12161213
SmallVector<int64_t> writeShape;
12171214

12181215
for (auto tileSize : packOp.getMixedTiles()) {

0 commit comments

Comments
 (0)