Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// tile factors.
DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();

/// Return the tile sizes as OpFoldResult.
/// Return the tile sizes as OpFoldResult. A tile size defined by a constant
/// op is returned as that op's result Value, not as the constant Attribute.
SmallVector<OpFoldResult> getMixedTiles();

/// Return the tile sizes as `int64_t`. If a tile size is dynamic
Expand Down
64 changes: 29 additions & 35 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();

Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
int64_t numTiles = destRank - srcRank;
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
int64_t numberOfTiles = innerDimsPos.size();

// 1. Extract the inner tile sizes.
// Where possible, values are replaced with constant attributes (to match the
// behaviour of `getPackOpSourceOrPaddedSource`).
SmallVector<OpFoldResult> tileSizes;
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
// Rather than taking the tile size as is, extract the actual constant
// value Attribute where possible, e.g.:
// [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
auto [_, tileSize] =
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
tileSizes.push_back(tileSize);
}
}
// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);

// 2. Transpose the input to match the inner tile order:
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
// 1. All outer dims are 1 - the corresponding transposition order doesn't
// - All outer dims are 1 - the corresponding transposition order doesn't
// matter, but requires all dim indices to be present.

// 2.1 Get the permutation for linalg.transpose
SmallVector<int64_t> srcPermForTranspose;
ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
// rank of the inner tiling, correspond to the last `k` indices of the
Expand All @@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// rank of the source tensor. For example if we have a source tensor with
// indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
// indices are [1, 2], and the transpose will be [1, 2, 3, 0].
if (llvm::is_contained(innerDimPos, i))
if (llvm::is_contained(innerDimsPos, i))
continue;
srcPermForTranspose.push_back(i);
}
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());

// 2.2 Create the init tensor for linalg.transpose with the correct shape
SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
oneIdxAttr);
shapeForEmptyOp.append(packOp.getMixedTiles());

// getMixedTiles() may contain Values defined by constant ops rather than
// constant Attributes. Fold such Values to their Attributes where possible.
llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
[&](OpFoldResult ofr) {
if (auto val = llvm::dyn_cast<Value>(ofr))
return getAsOpFoldResult(val);
return ofr;
});

LDBG() << "Pack permutation: " << packOp;
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);

// 2.1 Create tensor.empty (init value for TransposeOp)
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
oneIdxAttr);
transShapeForEmptyOp.append(tileSizes);

applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
srcPermForTranspose);
Value empty =
tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
packOp.getSourceType().getElementType());
Value empty = tensor::EmptyOp::create(
rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());

// 2.2 Create linalg.transpose
// 2.3 Create linalg.transpose
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);

Expand All @@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
// Outer dims are all 1s!
SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
oneIdxAttr);
SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
SmallVector<int64_t> writeShape;

for (auto tileSize : packOp.getMixedTiles()) {
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Dialect/Linalg/decompose-pack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,22 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
// CHECK: return %[[INSERT]]

// -----

// Pack whose tiled dims include source dim 0: inner_dims_pos = [0, 3] with
// static inner_tiles = [8, 1]. Every outer dim of the packed result is 1.
func.func @pack_with_zero_pos_tile_size(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
%pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
return %pack : tensor<1x1x1x1x8x1xf32>
}

// CHECK-LABEL: func.func @pack_with_zero_pos_tile_size
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[SRC]] : tensor<8x1x1x1xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x1xf32>)
// CHECK-SAME: permutation = [1, 2, 0, 3]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
// CHECK: return %[[INSERT]]