Skip to content

Commit 66e520e

Browse files
committed
fixup! [mlir][linalg] Fix and Refactor DecomposeOuterUnitDimsUnPackOpPattern
Update comments
1 parent e56f5cb commit 66e520e

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,21 +1261,30 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12611261
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
12621262
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
12631263

1264-
// The sizes, affset and strides attributes for ExtractSliceOp.
1264+
// The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1265+
// dims:
1266+
// [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1267+
SmallVector<int64_t> readShapeForExtractSlice;
1268+
// The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1269+
// outer-tiled-dims being all 1), this will be
1270+
// [ outer-untiled-dims, tile-sizes ]
12651271
SmallVector<OpFoldResult> extractSliceSizes;
1272+
// The offset and strides attributes for ExtractSliceOp.
12661273
SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
12671274
SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
1268-
// The shape for ExtractSliceOp (due to rank-reducing, this is likely !=
1269-
// extractSliceSizes).
1270-
SmallVector<int64_t> readShapeForExtractSlice;
12711275

12721276
// Shape for EmptyOp that's used as the init value for TransposeOp below.
1273-
// This should match tile size + transposition.
1277+
// This should be:
1278+
// [ outer-untiled-dims, tile-sizes ]
1279+
// However, skip unit dims - TransposeOp (below) applies rank-reduced
1280+
// permutation.
12741281
SmallVector<OpFoldResult> shapeForEmptyOp;
12751282

12761283
for (auto i : llvm::seq<unsigned>(0, destRank)) {
1277-
// Given the assumption that all outer tiled dims are 1, the corresponding
1278-
// slice size to read is also 1. As this will be rank-reducing "extract
1284+
// Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1285+
//
1286+
// As all outer tiled dims are 1, so the corresponding
1287+
// slice size to read will also 1. As this will be rank-reducing "extract
12791288
// slice" (i.e. the unit dims will be "collapsed"), there's no need to
12801289
// update:
12811290
// * the output shape for ExtractSliceOp, nor
@@ -1285,7 +1294,8 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12851294
continue;
12861295
}
12871296

1288-
// Compute sizes attribute for ExtractSliceOp + EmptyOp
1297+
// Compute sizes attribute for ExtractSliceOp + EmptyOp -
1298+
// outer-untiled-dims
12891299
if (ShapedType::isDynamic(srcShape[i])) {
12901300
OpFoldResult dynamicDim =
12911301
rewriter.create<tensor::DimOp>(loc, source, i).getResult();
@@ -1296,15 +1306,15 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12961306
if (srcShape[i] != 1)
12971307
shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
12981308
}
1299-
// Compute the output shape for ExtractSliceOp (take into account
1300-
// rank-reducing)
1309+
// Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1310+
// into account rank-reducing)
13011311
if (srcShape[i] != 1) {
13021312
readShapeForExtractSlice.push_back(srcShape[i]);
13031313
}
13041314
}
1315+
// Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1316+
// shape for EmptyOp.
13051317
auto mixedTiles = unpackOp.getMixedTiles();
1306-
// TODO: This effectively assumes that that tile sizes match the trailing
1307-
// sizes for ExtractSliceOp and EmptyOp - document this.
13081318
extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
13091319
shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
13101320

0 commit comments

Comments
 (0)