6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1651,7 +1651,10 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
/// tensor::InsertSliceOp ops.
///
/// Requires that all the outer dims of the input linalg::PackOp are 1.
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
/// Note that this constraint effectively means that only one tile is packed.
///
/// In addition, assumes that the un-tiled outer dims are not permuted.
///
/// Before:
/// ```
@@ -1691,6 +1694,7 @@ struct DecomposeOuterUnitDimsPackOpPattern
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
///
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
/// Note that this constraint effectively means that only one tile is unpacked.
///
/// Before:
/// ```
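To make the constraints above concrete, here is a minimal MLIR sketch; the function names and shapes are illustrative and not taken from this patch. The first pack has its only tiled outer dim equal to 1 (exactly one tile is packed) and leaves the un-tiled outer dims in their original order, so it is decomposable; the second would pack two tiles along the tiled dim and is out of scope for this pattern.

```mlir
// Decomposable: the tiled outer dim is 32/32 = 1 and the un-tiled outer dims
// (2 and 16) keep their original order.
func.func @one_tile(%src: tensor<2x32x16xf32>, %dest: tensor<2x1x16x32xf32>) -> tensor<2x1x16x32xf32> {
  %0 = linalg.pack %src inner_dims_pos = [1] inner_tiles = [32]
      into %dest : tensor<2x32x16xf32> -> tensor<2x1x16x32xf32>
  return %0 : tensor<2x1x16x32xf32>
}

// Not decomposable by this pattern: the tiled outer dim is 64/32 = 2, i.e. two
// tiles are packed.
func.func @two_tiles(%src: tensor<2x64x16xf32>, %dest: tensor<2x2x16x32xf32>) -> tensor<2x2x16x32xf32> {
  %0 = linalg.pack %src inner_dims_pos = [1] inner_tiles = [32]
      into %dest : tensor<2x64x16xf32> -> tensor<2x2x16x32xf32>
  return %0 : tensor<2x2x16x32xf32>
}
```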
11 changes: 9 additions & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {

SmallVector<int64_t> PackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
auto packedShape = getDestType().getShape();
SmallVector<int64_t> outerDims(getAllOuterDims());
SmallVector<int64_t> res;

// Recover the original order of the outer dims.
SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
invertPermutationVector(outerDimPermInv);
if (!outerDimPermInv.empty())
applyPermutationToVector(outerDims, outerDimPermInv);

// Collect the outer dims corresponding to the tiled inner dims.
for (auto index : innerDimsPos)
res.push_back(packedShape[index]);
res.push_back(outerDims[index]);

return res;
}
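For illustration, a hypothetical pack op (shapes and attributes invented for this sketch, not part of the patch) that shows why inverting outer_dims_perm matters here:

```mlir
func.func @tiled_outer_dims(%src: tensor<7x32x16xf32>, %dest: tensor<16x7x4x8xf32>) -> tensor<16x7x4x8xf32> {
  // Outer dims in source order after tiling dim 1 by 8: [7, 32/8 = 4, 16].
  // outer_dims_perm = [2, 0, 1] reorders them to [16, 7, 4] in the dest.
  %0 = linalg.pack %src outer_dims_perm = [2, 0, 1] inner_dims_pos = [1]
      inner_tiles = [8] into %dest : tensor<7x32x16xf32> -> tensor<16x7x4x8xf32>
  return %0 : tensor<16x7x4x8xf32>
}
```

For this op, getAllOuterDims() yields [16, 7, 4] (dest order); inverting outer_dims_perm recovers the source order [7, 4, 16], and getTiledOuterDims() then picks the entry at inner_dims_pos[0] = 1, i.e. [4]. Indexing the packed shape directly, as the old code did, would return 7 instead.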
87 changes: 67 additions & 20 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1134,22 +1134,42 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,

LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
linalg::PackOp packOp, PatternRewriter &rewriter) const {
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
if (llvm::any_of(packOp.getAllOuterDims(),
if (llvm::any_of(packOp.getTiledOuterDims(),
[](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "not all outer dimensions of the result are 1s");
}

ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
auto outerDimsPerm = packOp.getOuterDimsPerm();

// Verify that there are no non-unit un-tiled outer dims that are permuted.
// Supporting such cases would require refining the logic that generates the
// transpose op.
// Track the previously seen un-tiled dim; captured by reference so that no
// state leaks across pattern applications.
int64_t prev = 0;
if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp, &prev](int64_t dim) {
// Tiled dims are not relevant here.
if (llvm::is_contained(innerDimsPos, dim))
return true;
// Was this dim permuted? Note, permuting unit dims is fine.
if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
packOp.getType().getShape()[dim] != 1))
return false;

prev = dim;
return true;
})) {
return rewriter.notifyMatchFailure(
packOp, "at least one non-unit, un-tiled outer dim is permuted; "
"this is not supported ATM!");
}

Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();

int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
int64_t numberOfTiles = innerDimsPos.size();

// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
@@ -1160,10 +1180,14 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
// - All outer dims are 1 - the corresponding transposition order doesn't
// matter, but requires all dim indices to be present.

// 2.1 Get the permutation for linalg.transpose
// - All tiled outer dims are 1 - the corresponding transposition order
// doesn't matter, but requires all dim indices to be present.
// - Un-tiled outer dims remain un-permuted. (TODO: Fail when this does not
//   hold)

Review comment (Contributor):
So the pattern crashes at the moment when this happens? I am a bit afraid that
this can cause hard to debug errors.

Wouldn't it be sufficient if you check that outer_dims_perm excluding the
dimensions present in inner_dims_pos are monotonically growing?

Reply (Contributor Author):
Thanks for the suggestion, updated in the latest commit!

// 2.1 Get the permutation for linalg.transpose:
// [ untiled-dims, inner-dims-pos ]
// Note, this logic assumes that the untiled dims are not permuted.
SmallVector<int64_t> srcPermForTranspose;
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
@@ -1179,9 +1203,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());

// 2.2 Create the init tensor for linalg.transpose with the correct shape
SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
oneIdxAttr);
// 2.2 Create the init tensor for linalg.transpose with the correct shape:
// [ untiled-dims, tiled-dims ]
ShapedType inputTy = cast<ShapedType>(input.getType());
SmallVector<OpFoldResult> shapeForEmptyOp;
for (int64_t i = 0; i < srcRank; i++) {
if (llvm::is_contained(innerDimsPos, i)) {
// The tiled dims are appended after this loop.
continue;
}
if (inputTy.isStaticDim(i))
shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
else
shapeForEmptyOp.emplace_back(
tensor::DimOp::create(rewriter, loc, input, i).getResult());
}
shapeForEmptyOp.append(packOp.getMixedTiles());

// getMixedTiles() may contain Values pointing to constant ops, not the
@@ -1206,23 +1242,34 @@

// 3. Insert the inner tile to the destination:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
// Outer dims are all 1s!
SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
SmallVector<int64_t> writeShape;

// Compute the sizes attribute:
// [ outer-dims, tile-sizes ]
// Note that the output from the transpose Op excludes the tiled outer dims.
// Given the assumptions (all tiled outer dims == 1), we can safely use a
// rank-expanding tensor.insert_slice. Rather than manually computing where to
// insert new unit dims (resulting from the expansion), use the Pack op
// attributes.
SmallVector<OpFoldResult> writeSizes;
for (auto size : packOp.getAllOuterDims()) {
writeSizes.push_back(rewriter.getIndexAttr(size));
}

for (auto tileSize : packOp.getMixedTiles()) {
auto [tileSizeStatic, tileSizeOfr] =
auto [_, tileSizeOfr] =
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
writeSizes.push_back(tileSizeOfr);
writeShape.push_back(tileSizeStatic);
}

// 4. Replace tensor.packOp with tensor.insert_slice created above
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);

// TODO: A constructor that doesn't require strides nor offsets.
auto insert = tensor::InsertSliceOp::create(
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
writeOffsets, writeSizes, writeStrides);

// 4. Replace tensor.packOp with tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());

return success();
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
sourceTensorType.getEncoding());
}

// TODO: This uses neither offsets nor strides!
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
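A small MLIR sketch of what the TODO above points at (the shapes and function name are illustrative): the inferred result type of a non-rank-reducing tensor.extract_slice is determined by the sizes alone, so both slices below get the same type even though their offsets and strides differ.

```mlir
func.func @same_inferred_type(%t: tensor<8x8xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) {
  // Offsets [0, 0], strides [1, 1].
  %a = tensor.extract_slice %t[0, 0] [2, 2] [1, 1] : tensor<8x8xf32> to tensor<2x2xf32>
  // Different offsets and strides, same sizes -> same result type.
  %b = tensor.extract_slice %t[1, 3] [2, 2] [2, 2] : tensor<8x8xf32> to tensor<2x2xf32>
  return %a, %b : tensor<2x2xf32>, tensor<2x2xf32>
}
```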
60 changes: 60 additions & 0 deletions mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -31,6 +31,25 @@ func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi

// -----

func.func @NCHW_to_NCHWc(%src: tensor<2x32x16x8xf32>, %dest: tensor<2x1x16x8x32xf32>) -> tensor<2x1x16x8x32xf32> {
%pack = linalg.pack %src
inner_dims_pos = [1]
inner_tiles = [32] into %dest
: tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32>
return %pack : tensor<2x1x16x8x32xf32>
}
// CHECK-LABEL: func.func @NCHW_to_NCHWc(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x16x8x32xf32>
// CHECK: %[[TR:.*]] = linalg.transpose ins(%[[SRC]] : tensor<2x32x16x8xf32>) outs(%[[INIT]] : tensor<2x16x8x32xf32>) permutation = [0, 2, 3, 1]
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
// CHECK-SAME: : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32>
// CHECK: return %[[RES]] : tensor<2x1x16x8x32xf32>

// -----

func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> {
%0 = linalg.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
return %0 : tensor<1x1x8x2xf32>
@@ -157,6 +176,8 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t

// -----

// Note - un-tiled outer dims are permuted. However, these are unit dims, which is supported.

func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x1x5x1xf32>, %output: tensor<1x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x1x1x1x2x?xf32> {
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x1x5x1xf32> -> tensor<1x1x1x1x2x?xf32>
return %0 : tensor<1x1x1x1x2x?xf32>
@@ -182,6 +203,28 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x

// -----

// Similar to the example above, but one of the permuted un-tiled outer dims is non-unit: (7, 1) -> (1, 7).

func.func @negative_not_all_dims_tiled_outer_dim_0_permuted(%input: tensor<7x1x5x1xf32>, %output: tensor<1x7x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x7x1x1x2x?xf32> {
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<7x1x5x1xf32> -> tensor<1x7x1x1x2x?xf32>
return %0 : tensor<1x7x1x1x2x?xf32>
}
// CHECK-LABEL: func.func @negative_not_all_dims_tiled_outer_dim_0_permuted
// CHECK: linalg.pack

// -----

// Similar to the example above, but one of the permuted un-tiled outer dims is non-unit: (1, 7) -> (7, 1).

func.func @negative_not_all_dims_tiled_outer_dim_1_permuted(%input: tensor<1x7x5x1xf32>, %output: tensor<7x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<7x1x1x1x2x?xf32> {
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x7x5x1xf32> -> tensor<7x1x1x1x2x?xf32>
return %0 : tensor<7x1x1x1x2x?xf32>
}
// CHECK-LABEL: func.func @negative_not_all_dims_tiled_outer_dim_1_permuted
// CHECK: linalg.pack

// -----

func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
%0 = linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x8xf32> -> tensor<1x1x32x8xf32>
return %0 : tensor<1x1x32x8xf32>
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
// CHECK: return %[[INSERT]]

// -----
/// Note "126", which is a non-unit tile-outer-dim. This is not supported.

func.func @negative_non_unit_tiled_outer_dim(%dest: tensor<1x126x1x1x8xf32>, %src: tensor<1x1x1x1001xf32>, %pad: f32) -> tensor<1x126x1x1x8xf32> {
%pack = linalg.pack %src
padding_value(%pad : f32)
outer_dims_perm = [0, 3, 2, 1]
inner_dims_pos = [3]
inner_tiles = [8]
into %dest : tensor<1x1x1x1001xf32>
-> tensor<1x126x1x1x8xf32>

return %pack : tensor<1x126x1x1x8xf32>
}
// CHECK-LABEL: @negative_non_unit_tiled_outer_dim(
// CHECK: linalg.pack