Skip to content

Commit 8d4f317

Browse files
authored
[mlir][linalg] Fix UnPackOp::getTiledOuterDims (#152960)
Fixes `getTiledOuterDims` by making sure that the `outer_dims_perm` attribute from `linalg.unpack` is taken into account. Fixes #152037
1 parent 38853a0 commit 8d4f317

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1690,7 +1690,7 @@ struct DecomposeOuterUnitDimsPackOpPattern
16901690
/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
16911691
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
16921692
///
1693-
/// Requires that all the outer dims of the input linalg::PackOp are 1.
1693+
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
16941694
///
16951695
/// Before:
16961696
/// ```

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5767,11 +5767,18 @@ ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
57675767

57685768
SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
57695769
auto innerDimsPos = getInnerDimsPos();
5770-
auto packedShape = getSourceType().getShape();
5770+
SmallVector<int64_t> outerDims(getAllOuterDims());
57715771
SmallVector<int64_t> res;
57725772

5773+
// Recover the original order of the outer dims.
5774+
SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5775+
invertPermutationVector(outerDimPermInv);
5776+
if (!outerDimPermInv.empty())
5777+
applyPermutationToVector(outerDims, outerDimPermInv);
5778+
5779+
// Collect the outer dims corresponding to the tilled inner dims.
57735780
for (auto index : innerDimsPos)
5774-
res.push_back(packedShape[index]);
5781+
res.push_back(outerDims[index]);
57755782

57765783
return res;
57775784
}

mlir/test/Dialect/Linalg/decompose-unpack.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,20 @@ func.func @unpack_with_non_trailing_dimensions_in_inner_dims(%arg0: tensor<1x1x1
203203
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
204204
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
205205
// CHECK: return %[[INSERT]]
206+
207+
// -----
208+
209+
/// Note "126", which is a non-unit tile-outer-dim. This is not supported.
210+
211+
func.func @negative_non_unit_tiled_outer_dim(%src: tensor<1x126x1x1x8xf32>, %dest: tensor<1x1x1x1001xf32>) -> tensor<1x1x1x1001xf32> {
212+
%unpack = linalg.unpack %src
213+
outer_dims_perm = [0, 3, 2, 1]
214+
inner_dims_pos = [3]
215+
inner_tiles = [8]
216+
into %dest : tensor<1x126x1x1x8xf32>
217+
-> tensor<1x1x1x1001xf32>
218+
219+
return %unpack : tensor<1x1x1x1001xf32>
220+
}
221+
// CHECK-LABEL: @negative_non_unit_tiled_outer_dim(
222+
// CHECK: linalg.unpack

0 commit comments

Comments
 (0)