13 changes: 10 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1650,8 +1650,12 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
/// Rewrites a linalg::PackOp into a sequence of:
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
/// tensor::InsertSliceOp ops.
/// (InsertSliceOp is rank-expanding).
///
/// Requires that all the outer dims of the input linalg::PackOp are 1.
/// Requires that all the tiled-outer-dims of the input linalg::PackOp are 1.
/// Note that this constraint means that, effectively, exactly one tile is packed.
///
/// In addition, assumes that the un-tiled-outer-dims are not permuted.
///
/// Before:
/// ```
@@ -1687,10 +1691,13 @@ struct DecomposeOuterUnitDimsPackOpPattern
PatternRewriter &rewriter) const override;
};

/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
/// Rewrites a linalg::UnPackOp into a sequence of:
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
/// (ExtractSliceOp is rank-reducing).
///
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
/// Requires that all the tiled-outer-dims of the input linalg::UnPackOp are 1.
/// Note that this constraint means that, effectively, exactly one tile is
/// unpacked.
///
/// Before:
/// ```
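For reference, a rough sketch of the rewrite performed by `DecomposeOuterUnitDimsPackOpPattern`, mirroring the `NCHW_to_NCHWc` test added in this patch (the `tensor.pad` step is omitted because the source needs no padding):

```mlir
// Before: the only tiled outer dim (dim 1 of the result) is 1, i.e. exactly
// one tile is packed.
%pack = linalg.pack %src inner_dims_pos = [1] inner_tiles = [32]
    into %dest : tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32>

// After: transpose the tiled dim to the back, then write the single tile into
// the destination with a rank-expanding tensor.insert_slice.
%init = tensor.empty() : tensor<2x16x8x32xf32>
%tr = linalg.transpose ins(%src : tensor<2x32x16x8xf32>)
                       outs(%init : tensor<2x16x8x32xf32>) permutation = [0, 2, 3, 1]
%res = tensor.insert_slice %tr into %dest[0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
    : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32>
```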
11 changes: 9 additions & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {

SmallVector<int64_t> PackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
auto packedShape = getDestType().getShape();
SmallVector<int64_t> outerDims(getAllOuterDims());
SmallVector<int64_t> res;

// Recover the original order of the outer dims.
SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
invertPermutationVector(outerDimPermInv);
if (!outerDimPermInv.empty())
applyPermutationToVector(outerDims, outerDimPermInv);

// Collect the outer dims corresponding to the tiled inner dims.
for (auto index : innerDimsPos)
res.push_back(packedShape[index]);
res.push_back(outerDims[index]);

return res;
}
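With the inverse permutation applied, `getTiledOuterDims` now returns the outer dims in the original (source) dim order rather than the permuted destination order. A small hypothetical example (not taken from the patch's tests) of the difference:

```mlir
// Source dim 0 (16) is tiled by 8 (outer dim 2); source dim 1 (4) stays
// un-tiled (outer dim 4). outer_dims_perm = [1, 0] swaps them in the result,
// so getAllOuterDims() returns [4, 2].
%pack = linalg.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [0]
    inner_tiles = [8] into %dest : tensor<16x4xf32> -> tensor<4x2x8xf32>
```

Inverting `outer_dims_perm` recovers the source order `[2, 4]`, so `getTiledOuterDims()` returns `[2]` for `inner_dims_pos = [0]`, whereas indexing the packed shape directly, as the old code did, would have returned `4`.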
89 changes: 69 additions & 20 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,

LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
linalg::PackOp packOp, PatternRewriter &rewriter) const {
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
if (llvm::any_of(packOp.getAllOuterDims(),
if (llvm::any_of(packOp.getTiledOuterDims(),
[](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "not all outer dimensions of the result are 1s");
}

ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
auto outerDimsPerm = packOp.getOuterDimsPerm();

// Verify that no non-unit, un-tiled outer dims are permuted. Supporting such
// cases would require refining the logic that generates the Transpose Op.
int64_t prev = 0;
if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp, &prev](int64_t dim) {
// Skip tiled dims - these can be permuted.
if (llvm::is_contained(innerDimsPos, dim))
return true;

// Check whether this dim has been permuted. Permuting unit dims is fine
// as that's effectively a no-op.
if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
packOp.getType().getShape()[dim] != 1))
return false;

prev = dim;
return true;
})) {
return rewriter.notifyMatchFailure(
packOp, "At least one non-unit and un-tiled outer dim is permuted, "
"this is not supported ATM!");
}
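// For example, the negative_not_all_dims_tiled_* tests added in this patch use
// outer_dims_perm = [1, 0, 2, 3] with inner_dims_pos = [3, 2] to swap a
// non-unit, un-tiled dim of size 7 with a unit dim; the check above rejects
// them and the pack op is left undecomposed.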

Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();

int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
int64_t numberOfTiles = innerDimsPos.size();

// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
@@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
// - All outer dims are 1 - the corresponding transposition order doesn't
// matter, but requires all dim indices to be present.
// - All tiled outer dims are 1 - the corresponding transposition order
// doesn't matter, but requires all dim indices to be present.
// - Un-tiled outer dims remain un-permuted.

// 2.1 Get the permutation for linalg.transpose
// 2.1 Get the permutation for linalg.transpose:
// [ untiled-dims, inner-dims-pos ]
// Note that this logic assumes the un-tiled dims are not permuted.
SmallVector<int64_t> srcPermForTranspose;
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
@@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
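// For example, when packing tensor<2x32x16x8xf32> with innerDimsPos = [1]
// (cf. the NCHW_to_NCHWc test added in this patch), the un-tiled dims
// {0, 2, 3} are kept in order and the tiled dim is appended, yielding
// srcPermForTranspose = [0, 2, 3, 1].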

// 2.2 Create the init tensor for linalg.transpose with the correct shape
SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
oneIdxAttr);
// 2.2 Create the init tensor for linalg.transpose with the correct shape:
// [ untiled-dims, tiled-dims ]
ShapedType inputTy = cast<ShapedType>(input.getType());
SmallVector<OpFoldResult> shapeForEmptyOp;
for (int64_t i = 0; i < srcRank; i++) {
if (llvm::is_contained(innerDimsPos, i)) {
// The tiled dims are appended after this loop.
continue;
}
if (inputTy.isStaticDim(i))
shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
else
shapeForEmptyOp.emplace_back(
tensor::DimOp::create(rewriter, loc, input, i).getResult());
}
shapeForEmptyOp.append(packOp.getMixedTiles());
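// Continuing the NCHW_to_NCHWc example: the un-tiled dims contribute
// [2, 16, 8] and the mixed tiles add [32], so the init tensor for the
// transpose is tensor<2x16x8x32xf32>.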

// getMixedTiles() may contain Values pointing to constant ops, not the
@@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);

// 3. Insert the inner tile to the destination:
// 3. Insert the inner tile into the destination tensor:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
// Outer dims are all 1s!
SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
SmallVector<int64_t> writeShape;

// Compute the sizes attribute:
// [ outer-dims, tile-sizes ]
// Note that the output from the transpose Op excludes the tiled outer dims.
// However, given the assumption that:
// * all tiled outer dims == 1,
// we can just use a rank-expanding tensor.insert_slice.
SmallVector<OpFoldResult> writeSizes;
for (auto size : packOp.getAllOuterDims()) {
writeSizes.push_back(rewriter.getIndexAttr(size));
}

for (auto tileSize : packOp.getMixedTiles()) {
auto [tileSizeStatic, tileSizeOfr] =
auto [_, tileSizeOfr] =
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
writeSizes.push_back(tileSizeOfr);
writeShape.push_back(tileSizeStatic);
}
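// Continuing the NCHW_to_NCHWc example: getAllOuterDims() yields [2, 1, 16, 8]
// and the mixed tiles add [32], so writeSizes = [2, 1, 16, 8, 32].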

// 4. Replace tensor.packOp with tensor.insert_slice created above
// TODO: Add a constructor for tensor.insert_slice that requires neither
// strides nor offsets.
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);

auto insert = tensor::InsertSliceOp::create(
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
writeOffsets, writeSizes, writeStrides);

// 4. Replace linalg.pack with the tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());

return success();
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
sourceTensorType.getEncoding());
}

// TODO: This uses neither offsets nor strides!
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
61 changes: 61 additions & 0 deletions mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -31,6 +31,25 @@ func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi

// -----

func.func @NCHW_to_NCHWc(%src: tensor<2x32x16x8xf32>, %dest: tensor<2x1x16x8x32xf32>) -> tensor<2x1x16x8x32xf32> {
%pack = linalg.pack %src
inner_dims_pos = [1]
inner_tiles = [32] into %dest
: tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32>
return %pack : tensor<2x1x16x8x32xf32>
}
// CHECK-LABEL: func.func @NCHW_to_NCHWc(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x16x8x32xf32>
// CHECK: %[[TR:.*]] = linalg.transpose ins(%[[SRC]] : tensor<2x32x16x8xf32>) outs(%[[INIT]] : tensor<2x16x8x32xf32>) permutation = [0, 2, 3, 1]
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
// CHECK-SAME: : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32>
// CHECK: return %[[RES]] : tensor<2x1x16x8x32xf32>

// -----

func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> {
%0 = linalg.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
return %0 : tensor<1x1x8x2xf32>
@@ -157,6 +176,8 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t

// -----

// Note - the un-tiled outer dims are permuted. However, these are unit dims, which is supported.

func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x1x5x1xf32>, %output: tensor<1x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x1x1x1x2x?xf32> {
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x1x5x1xf32> -> tensor<1x1x1x1x2x?xf32>
return %0 : tensor<1x1x1x1x2x?xf32>
@@ -182,6 +203,28 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x

// -----

// Similar to the example above, but one of the permuted un-tiled outer dims is non-unit: (7, 1) -> (1, 7).

func.func @negative_not_all_dims_tiled_outer_dim_0_permuted(%input: tensor<7x1x5x1xf32>, %output: tensor<1x7x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x7x1x1x2x?xf32> {
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<7x1x5x1xf32> -> tensor<1x7x1x1x2x?xf32>
return %0 : tensor<1x7x1x1x2x?xf32>
}
// CHECK-LABEL: func.func @negative_not_all_dims_tiled_outer_dim_0_permuted
// CHECK: linalg.pack

// -----

// Similar to the example above, but one of the permuted un-tiled outer dims is non-unit: (1, 7) -> (7, 1).

func.func @negative_not_all_dims_tiled_outer_dim_1_permuted(%input: tensor<1x7x5x1xf32>, %output: tensor<7x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<7x1x1x1x2x?xf32> {
%0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x7x5x1xf32> -> tensor<7x1x1x1x2x?xf32>
return %0 : tensor<7x1x1x1x2x?xf32>
}
// CHECK-LABEL: func.func @negative_not_all_dims_tiled_outer_dim_1_permuted
// CHECK: linalg.pack

// -----

func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
%0 = linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x8xf32> -> tensor<1x1x32x8xf32>
return %0 : tensor<1x1x32x8xf32>
@@ -295,3 +338,21 @@ func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
// CHECK: return %[[INSERT]]

// -----

/// Note "126", which is a non-unit tiled-outer-dim. This is not supported.

func.func @negative_non_unit_tiled_outer_dim(%dest: tensor<1x126x1x1x8xf32>, %src: tensor<1x1x1x1001xf32>, %pad: f32) -> tensor<1x126x1x1x8xf32> {
%pack = linalg.pack %src
padding_value(%pad : f32)
outer_dims_perm = [0, 3, 2, 1]
inner_dims_pos = [3]
inner_tiles = [8]
into %dest
: tensor<1x1x1x1001xf32> -> tensor<1x126x1x1x8xf32>

return %pack : tensor<1x126x1x1x8xf32>
}
// CHECK-LABEL: @negative_non_unit_tiled_outer_dim(
// CHECK: linalg.pack