Skip to content

Commit 8ba75d4

Browse files
committed
Fix comments
1 parent f331d44 commit 8ba75d4

File tree

3 files changed

+28
-22
lines changed

3 files changed

+28
-22
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,11 +1650,12 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
16501650
/// Rewrites a linalg::PackOp into a sequence of:
16511651
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
16521652
/// tensor::InsertSliceOp ops.
1653+
/// (InsertSliceOp is rank-expanding).
16531654
///
1654-
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
1655-
/// Note that this constraint means to effectively one tile is packed.
1655+
/// Requires that all the tiled-outer-dims of the input linalg::PackOp are 1.
1656+
/// Note that this constraint means that effectively exactly one tile is packed.
16561657
///
1657-
/// In addition, assumes that the un-tiled outer dims are not permuted.
1658+
/// In addition, assumes that the un-tiled-outer-dims are not permuted.
16581659
///
16591660
/// Before:
16601661
/// ```
@@ -1690,11 +1691,13 @@ struct DecomposeOuterUnitDimsPackOpPattern
16901691
PatternRewriter &rewriter) const override;
16911692
};
16921693

1693-
/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
1694+
/// Rewrites a linalg::UnPackOp into a sequence of:
16941695
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
1696+
/// (ExtractSliceOp is rank-reducing).
16951697
///
1696-
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
1697-
/// Note that this constraint means to effectively one tile is unpacked.
1698+
/// Requires that all the tiled-outer-dims of the input linalg::UnPackOp are 1.
1699+
/// Note that this constraint means that effectively exactly one tile is
1700+
/// unpacked.
16981701
///
16991702
/// Before:
17001703
/// ```

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,15 +1143,18 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11431143
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
11441144
auto outerDimsPerm = packOp.getOuterDimsPerm();
11451145

1146-
// Verify that there are no non-unit un-tiled outer dims that are permuted.
1147-
// Supporting such cases will require refining the logic to generate the
1148-
// Transpose Op.
1146+
// Verify that there are no:
1147+
// * non-unit + un-tiled-outer-dims,
1148+
// that are permuted. Supporting such cases would require refining the logic
1149+
// that generates the Transpose Op.
11491150
if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
11501151
static int prev = 0;
1151-
// Tiled dims are not relevant here.
1152+
// Skip tiled dims - these can be permuted.
11521153
if (llvm::is_contained(innerDimsPos, dim))
11531154
return true;
1154-
// Was this dim permuted? Note, permuting unit dims is fine.
1155+
1156+
// Check whether this dim has been permuted. Permuting unit dims is fine
1157+
// as that's effectively a no-op.
11551158
if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
11561159
packOp.getType().getShape()[dim] != 1))
11571160
return false;
@@ -1182,8 +1185,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11821185
// Assumptions made:
11831186
// - All tiled outer dims are 1 - the corresponding transposition order
11841187
// doesn't matter, but requires all dim indices to be present.
1185-
// - Un-tiled outer dims remain un-permuted. (TODO: Fail when this does not
1186-
// hold)
1188+
// - Un-tiled outer dims remain un-permuted.
11871189

11881190
// 2.1 Get the permutation for linalg.transpose:
11891191
// [ untiled-dims, inner-dims-pos ]
@@ -1240,16 +1242,15 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12401242
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
12411243
srcPermForTranspose);
12421244

1243-
// 3. Insert the inner tile to the destination:
1245+
// 3. Insert the inner tile into the destination tensor:
12441246
// %inserted_tile = tensor.insert_slice(%transposed_tile)
12451247

12461248
// Compute the sizes attribute:
12471249
// [ outer-dims, tile-sizes ]
12481250
// Note that the output from the transpose Op excludes the tiled outer dims.
1249-
// Given the assumptions (all tiled outer dims == 1), we can safely use a
1250-
// rank-expanding tensor.insert_slice. Rather than manually computing where to
1251-
// insert new unit dims (resulting from the expansion), use the Pack op
1252-
// attributes.
1251+
// However, given the assumption that:
1252+
// * all tiled outer dims == 1,
1253+
// we can just use a rank-expanding tensor.insert_slice.
12531254
SmallVector<OpFoldResult> writeSizes;
12541255
for (auto size : packOp.getAllOuterDims()) {
12551256
writeSizes.push_back(rewriter.getIndexAttr(size));
@@ -1261,10 +1262,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12611262
writeSizes.push_back(tileSizeOfr);
12621263
}
12631264

1265+
// TODO: Add a constructor for tensor.insert_slice that doesn't require
1266+
// strides nor offsets.
12641267
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
12651268
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
12661269

1267-
// TODO: A constructor that doesn't require strides nor offsets.
12681270
auto insert = tensor::InsertSliceOp::create(
12691271
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
12701272
writeOffsets, writeSizes, writeStrides);

mlir/test/Dialect/Linalg/decompose-pack.mlir

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,16 +340,17 @@ func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x
340340
// CHECK: return %[[INSERT]]
341341

342342
// -----
343-
/// Note "126", which is a non-unit tile-outer-dim. This is not supported.
343+
344+
/// Note "126", which is a non-unit tiled-outer-dim. This is not supported.
344345

345346
func.func @negative_non_unit_tiled_outer_dim(%dest: tensor<1x126x1x1x8xf32>, %src: tensor<1x1x1x1001xf32>, %pad: f32) -> tensor<1x126x1x1x8xf32> {
346347
%pack = linalg.pack %src
347348
padding_value(%pad : f32)
348349
outer_dims_perm = [0, 3, 2, 1]
349350
inner_dims_pos = [3]
350351
inner_tiles = [8]
351-
into %dest : tensor<1x1x1x1001xf32>
352-
-> tensor<1x126x1x1x8xf32>
352+
into %dest
353+
: tensor<1x1x1x1001xf32> -> tensor<1x126x1x1x8xf32>
353354

354355
return %pack : tensor<1x126x1x1x8xf32>
355356
}

0 commit comments

Comments
 (0)