Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1178,13 +1178,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
int64_t destRank = packOp.getDestRank();
int64_t numTiles = destRank - srcRank;

if (!llvm::all_of(packOp.getInnerDimsPos(),
[&srcRank, &numTiles](int64_t dimPos) {
return dimPos >= (srcRank - numTiles - 1);
}))
return rewriter.notifyMatchFailure(
packOp, "Attempting to tile non-trailing source dims!");

// 1. Extract the inner tile sizes.
// Where possible, values are replaced with constant attributes (to match the
// behaviour of `getPackOpSourceOrPaddedSource`).
Expand All @@ -1204,16 +1197,24 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Two assumptions are made:
// 1. All outer dims are 1 - the corresponding transposition doesn't matter.
// 2. Inner dims position correspond to the trailing `numTiles` dims.
SmallVector<int64_t> tilesPermNormalized =
getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
// Assumptions made:
// 1. All outer dims are 1 - the corresponding transposition order doesn't
// matter, but requires all dim indices to be present.
SmallVector<int64_t> srcPermForTranspose;
for (int64_t i = 0; i < (srcRank - numTiles); i++)
ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
// rank of the inner tiling, correspond to the last `k` indices of the
// transpose permutation. This is done by first adding the indices not
// contained in the inner dimension positions, in increasing order from 0 to
// `n - 1`, where `n` is the rank of the source tensor. For example, if we
// have a source tensor with indices [0, 1, 2, 3] and inner dim positions of
// [3, 0], the remaining indices are [1, 2] and the transpose permutation
// will be [1, 2, 3, 0].
if (llvm::is_contained(innerDimPos, i))
continue;
srcPermForTranspose.push_back(i);

srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
}
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());

LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
<< "perm: " << llvm::interleaved(srcPermForTranspose)
Expand Down
45 changes: 45 additions & 0 deletions mlir/test/Dialect/Linalg/decompose-pack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,48 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK: return %[[INSERT]]

// -----

// Packs a `1x1x4` source whose tiled dimensions are not next to each other
// (`inner_dims_pos = [2, 0]`). Every outer dimension of the packed result is
// a unit dimension (`1x1x1`).
func.func @pack_with_non_adjacent_inner_dims_pos_and_unit_outer(
    %arg0: tensor<1x1x4xf32>,
    %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
  %packed = linalg.pack %arg0
      outer_dims_perm = [1, 2, 0]
      inner_dims_pos = [2, 0]
      inner_tiles = [4, 1]
      into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
  return %packed : tensor<1x1x1x4x1xf32>
}
// CHECK-LABEL: func.func @pack_with_non_adjacent_inner_dims_pos_and_unit_outer
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>)
// CHECK-SAME: permutation = [1, 2, 0]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
// CHECK: return %[[INSERT]]

// -----

// Packs a `1x1x4` source with `inner_dims_pos = [2, 1]`, i.e. the last source
// dimension together with its neighbour. Such positions are termed adjacent
// trailing dimensions; `[1, 2]` would qualify as well. Every outer dimension
// of the packed result is a unit dimension (`1x1x1`).
func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
    %arg0: tensor<1x1x4xf32>,
    %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
  %packed = linalg.pack %arg0
      outer_dims_perm = [1, 2, 0]
      inner_dims_pos = [2, 1]
      inner_tiles = [4, 1]
      into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
  return %packed : tensor<1x1x1x4x1xf32>
}
// CHECK-LABEL: func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>)
// CHECK-SAME: permutation = [0, 2, 1]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
// CHECK: return %[[INSERT]]
34 changes: 34 additions & 0 deletions mlir/test/Dialect/Linalg/decompose-unpack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,37 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tens
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]

// -----

// Unpacks a `1x1x1x4x1` tensor whose inner tiles map to non-adjacent
// dimensions of the result (`inner_dims_pos = [2, 0]`). Every outer dimension
// of the packed source is a unit dimension (`1x1x1`).
func.func @unpack_with_non_adjacent_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
  %0 = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32>
  return %0 : tensor<1x1x4xf32>
}
// CHECK-LABEL: func.func @unpack_with_non_adjacent_inner_dims_pos_and_unit_outer
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[SLICE]] : tensor<4x1xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
// Use the captured transpose result rather than a hard-coded SSA name, so the
// check does not depend on how the printer names the value.
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
// CHECK: return %[[INSERT]]

// -----

// Unpacks a `1x1x1x4x1` tensor with `inner_dims_pos = [2, 1]`. Every outer
// dimension of the packed source is a unit dimension (`1x1x1`).
// NOTE: despite the function name, positions `[2, 1]` are the two trailing
// (and adjacent) dimensions of the rank-3 result tensor.
func.func @unpack_with_non_trailing_dimensions_in_inner_dims(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
  %unpack = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32>
  return %unpack : tensor<1x1x4xf32>
}
// CHECK-LABEL: func.func @unpack_with_non_trailing_dimensions_in_inner_dims
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[SLICE]] : tensor<4x1xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
// Use the captured transpose result rather than a hard-coded SSA name, so the
// check does not depend on how the printer names the value.
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
// CHECK: return %[[INSERT]]