Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,10 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
/// tensor::InsertSliceOp ops.
///
/// Requires that all the outer dims of the input linalg::PackOp are 1.
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
/// Note that this constraint effectively means that only one tile is packed.
///
/// In addition, assumes that the un-tiled outer dims are not permuted.
///
/// Before:
/// ```
Expand Down Expand Up @@ -1691,6 +1694,7 @@ struct DecomposeOuterUnitDimsPackOpPattern
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
///
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
/// Note that this constraint effectively means that only one tile is unpacked.
///
/// Before:
/// ```
Expand Down
11 changes: 9 additions & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {

/// Returns the sizes of the outer dims of this PackOp that correspond to the
/// tiled inner dims (i.e. the dims listed in `inner_dims_pos`), reported in
/// the original (pre-`outer_dims_perm`) dimension order.
SmallVector<int64_t> PackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  SmallVector<int64_t> outerDims(getAllOuterDims());
  SmallVector<int64_t> res;

  // Recover the original order of the outer dims: `getAllOuterDims` reports
  // them in the permuted (destination) order, so apply the inverse of
  // `outer_dims_perm` before indexing with `inner_dims_pos`.
  SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
  invertPermutationVector(outerDimPermInv);
  if (!outerDimPermInv.empty())
    applyPermutationToVector(outerDims, outerDimPermInv);

  // Collect the outer dims corresponding to the tiled inner dims.
  for (auto index : innerDimsPos)
    res.push_back(outerDims[index]);

  return res;
}
Expand Down
58 changes: 40 additions & 18 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1134,9 +1134,7 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,

LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
linalg::PackOp packOp, PatternRewriter &rewriter) const {
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
if (llvm::any_of(packOp.getAllOuterDims(),
if (llvm::any_of(packOp.getTiledOuterDims(),
[](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "not all outer dimensions of the result are 1s");
Expand All @@ -1149,7 +1147,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
int64_t numberOfTiles = innerDimsPos.size();

// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
Expand All @@ -1160,10 +1157,14 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
// - All outer dims are 1 - the corresponding transposition order doesn't
// matter, but requires all dim indices to be present.

// 2.1 Get the permutation for linalg.transpose
// - All tiled outer dims are 1 - the corresponding transposition order
// doesn't matter, but requires all dim indices to be present.
// - Un-tiled outer dims remain un-permuted. (TODO: Fail when this does not
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the pattern crashes at the moment when this happens? I am a bit afraid that this can cause hard to debug errors.

Wouldn't it be sufficient if you check that outer_dims_perm excluding the dimensions present in inner_dims_pos are monotonically growing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, updated in the latest commit!

// hold)

// 2.1 Get the permutation for linalg.transpose:
// [ untiled-dims, inner-dims-pos ]
// Note, this logic assumes that the untiled dims are not permuted.
SmallVector<int64_t> srcPermForTranspose;
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
Expand All @@ -1179,9 +1180,19 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());

// 2.2 Create the init tensor for linalg.transpose with the correct shape
SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
oneIdxAttr);
// 2.2 Create the init tensor for linalg.transpose with the correct shape:
// [ untiled-dims, tiled-dims ]
ShapedType inputTy = cast<ShapedType>(input.getType());
SmallVector<OpFoldResult> shapeForEmptyOp;
for (int64_t i = 0; i < srcRank; i++) {
if (llvm::is_contained(innerDimsPos, i))
continue;
if (inputTy.isStaticDim(i))
shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
else
shapeForEmptyOp.emplace_back(
tensor::DimOp::create(rewriter, loc, input, i).getResult());
}
shapeForEmptyOp.append(packOp.getMixedTiles());

// getMixedTiles() may contain Values pointing to constant ops, not the
Expand All @@ -1206,23 +1217,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(

// 3. Insert the inner tile to the destination:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
// Outer dims are all 1s!
SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
SmallVector<int64_t> writeShape;

// Compute the sizes attribute:
// [ outer-dims, tile-sizes ]
// Note that the output from the transpose Op excludes the tiled outer dims.
// Given the assumptions (all tiled outer dims == 1), we can safely use a
// rank-expanding tensor.insert_slice. Rather than manually computing where to
// insert new unit dims (resulting from the expansion), use the Pack op
// attributes.
SmallVector<OpFoldResult> writeSizes;
for (auto size : packOp.getAllOuterDims()) {
writeSizes.push_back(rewriter.getIndexAttr(size));
}

for (auto tileSize : packOp.getMixedTiles()) {
auto [tileSizeStatic, tileSizeOfr] =
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
writeSizes.push_back(tileSizeOfr);
writeShape.push_back(tileSizeStatic);
}

// 4. Replace tensor.packOp with tensor.insert_slice created above
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);

// TODO: A constructor that requires neither strides nor offsets.
auto insert = tensor::InsertSliceOp::create(
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
writeOffsets, writeSizes, writeStrides);

// 4. Replace tensor.packOp with tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());

return success();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
sourceTensorType.getEncoding());
}

// TODO: This uses neither offsets nor strides!
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Dialect/Linalg/decompose-pack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi

// -----

// Packs only the channel dim (dim 1) of an NCHW tensor with tile size 32.
// The tiled outer dim in the result (dim 1, size 32/32 = 1) is a unit dim,
// while the un-tiled outer dims (N, H, W) are non-unit and un-permuted —
// exercising decomposition with non-unit *un-tiled* outer dims.
func.func @NCHW_to_NCHWc(%src: tensor<2x32x16x8xf32>, %dest: tensor<2x1x16x8x32xf32>) -> tensor<2x1x16x8x32xf32> {
%pack = linalg.pack %src
inner_dims_pos = [1]
inner_tiles = [32] into %dest
: tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32>
return %pack : tensor<2x1x16x8x32xf32>
}
// CHECK-LABEL: func.func @NCHW_to_NCHWc(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x16x8x32xf32>
// CHECK: %[[TR:.*]] = linalg.transpose ins(%[[SRC]] : tensor<2x32x16x8xf32>) outs(%[[INIT]] : tensor<2x16x8x32xf32>) permutation = [0, 2, 3, 1]
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
// CHECK-SAME: : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32>
// CHECK: return %[[RES]] : tensor<2x1x16x8x32xf32>

// -----

func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> {
%0 = linalg.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
return %0 : tensor<1x1x8x2xf32>
Expand Down Expand Up @@ -295,3 +314,20 @@ func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
// CHECK: return %[[INSERT]]

// -----
/// Note "126", which is a non-unit tiled outer dim. This is not supported.

// Negative test: after undoing outer_dims_perm = [0, 3, 2, 1], the outer dim
// corresponding to the tiled inner dim (inner_dims_pos = [3]) has size 126,
// i.e. it is non-unit, so the decomposition pattern must not apply and the
// linalg.pack op must survive.
func.func @negative_non_unit_tiled_outer_dim(%dest: tensor<1x126x1x1x8xf32>, %src: tensor<1x1x1x1001xf32>, %pad: f32) -> tensor<1x126x1x1x8xf32> {
%pack = linalg.pack %src
padding_value(%pad : f32)
outer_dims_perm = [0, 3, 2, 1]
inner_dims_pos = [3]
inner_tiles = [8]
into %dest : tensor<1x1x1x1001xf32>
-> tensor<1x126x1x1x8xf32>

return %pack : tensor<1x126x1x1x8xf32>
}
// CHECK-LABEL: @negative_non_unit_tiled_outer_dim(
// CHECK: linalg.pack