Skip to content

Commit a22959b

Browse files
committed
[mlir][linalg] Fix UnPackOp::getTiledOuterDims
Fixes `getTiledOuterDims` by making sure that the `outer_dims_perm` attribute from `linalg.unpack` is taken into account. Fixes #152037
1 parent a2d353e commit a22959b

File tree

4 files changed

+54
-3
lines changed

4 files changed

+54
-3
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1690,7 +1690,7 @@ struct DecomposeOuterUnitDimsPackOpPattern
16901690
/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
16911691
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
16921692
///
1693-
/// Requires that all the outer dims of the input linalg::PackOp are 1.
1693+
/// Requires that all the tile outer dims of the input linalg::PackOp are 1.
16941694
///
16951695
/// Before:
16961696
/// ```

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5765,13 +5765,48 @@ ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
57655765
return getSourceType().getShape().take_front(destRank);
57665766
}
57675767

5768+
static SmallVector<int64_t>
5769+
inversePerm(const llvm::SmallVector<int64_t> &perm) {
5770+
const size_t n = perm.size();
5771+
llvm::SmallVector<int64_t> invPerm(n);
5772+
5773+
for (size_t i = 0; i < n; ++i) {
5774+
assert(perm[i] >= 0 && static_cast<size_t>(perm[i]) < n &&
5775+
"Invalid permutation entry");
5776+
invPerm[perm[i]] = i;
5777+
}
5778+
5779+
return invPerm;
5780+
}
5781+
5782+
/// Compute the inverse of a permutation. Assumes `perm` is a valid permutation
5783+
/// of 0...n-1.
5784+
static SmallVector<int64_t> invertPermutation(SmallVector<int64_t> perm) {
5785+
const size_t permLen = perm.size();
5786+
llvm::SmallVector<int64_t> inv(permLen);
5787+
for (size_t i = 0; i < permLen; ++i) {
5788+
assert(perm[i] >= 0 && static_cast<size_t>(perm[i]) < permLen &&
5789+
"Invalid permutation entry");
5790+
inv[perm[i]] = i;
5791+
}
5792+
return inv;
5793+
}
5794+
57685795
SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
57695796
auto innerDimsPos = getInnerDimsPos();
5770-
auto packedShape = getSourceType().getShape();
5797+
SmallVector<int64_t> outerDims(getAllOuterDims());
57715798
SmallVector<int64_t> res;
57725799

5800+
// Invert outer-dims-perm and use it to restore the original order
5801+
// of the outer dims.
5802+
SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5803+
inversePerm(outerDimPermInv);
5804+
if (!outerDimPermInv.empty())
5805+
applyPermutationToVector(outerDims, outerDimPermInv);
5806+
5807+
// Collect the outer dims corresponding to the tilled inner dims.
57735808
for (auto index : innerDimsPos)
5774-
res.push_back(packedShape[index]);
5809+
res.push_back(outerDims[index]);
57755810

57765811
return res;
57775812
}

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
23102310
sourceTensorType.getEncoding());
23112311
}
23122312

2313+
// TODO: This uses neither offsets nor strides!
23132314
RankedTensorType ExtractSliceOp::inferResultType(
23142315
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
23152316
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {

mlir/test/Dialect/Linalg/decompose-unpack.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,18 @@ func.func @unpack_with_non_trailing_dimensions_in_inner_dims(%arg0: tensor<1x1x1
203203
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
204204
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
205205
// CHECK: return %[[INSERT]]
206+
207+
// -----
208+
209+
/// Note "126", which is a non-unit tile-outer-dim. This is not supported.
210+
211+
func.func @negative_non_unit_tiled_outer_dim(%src: tensor<1x126x1x1x8xf32>, %dest: tensor<1x1x1x1001xf32>) -> tensor<1x1x1x1001xf32> {
212+
%unpack = linalg.unpack %src
213+
outer_dims_perm = [0, 3, 2, 1]
214+
inner_dims_pos = [3]
215+
inner_tiles = [8]
216+
into %dest : tensor<1x126x1x1x8xf32>
217+
-> tensor<1x1x1x1001xf32>
218+
219+
return %unpack : tensor<1x1x1x1001xf32>
220+
}

0 commit comments

Comments
 (0)