From 9921df4f84a1bb3e9210e23bca6dabcf40b6b405 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Thu, 22 May 2025 15:55:55 +0000 Subject: [PATCH 1/4] [mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern Given the following example: ``` module { func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> { %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32> return %pack : tensor<1x1x1x4x1xf32> } } ``` We would generate an invalid transpose operation because the calculated permutation would be `[0, 2, 0]` which is semantically incorrect. As the permutation must contain unique integers corresponding to the source tensor dimensions. The following change modifies how we calculate the permutation array and ensures that the dimension indices given in the permutation array is unique. The above example would then translate to a transpose having a permutation of `[1, 2, 0]`. Following the rule, that the `inner_dim_pos` is appended to the permutation array and the preceding indices are filled with the remaining dimensions. --- .../Dialect/Linalg/Transforms/Transforms.cpp | 23 ++++++++++++------- mlir/test/Dialect/Linalg/decompose-pack.mlir | 19 +++++++++++++++ 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 615d1f66414b9..8f6488dbfab3d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1204,16 +1204,23 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // %init = tensor.empty() // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) - // Two assumptions are made: - // 1. All outer dims are 1 - the corresponding transposition doesn't matter. - // 2. Inner dims position correspond to the trailing `numTiles` dims. - SmallVector tilesPermNormalized = - getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos()); + // Assumptions made: + // 1. Inner dims position correspond to the trailing `numTiles` dims. SmallVector srcPermForTranspose; - for (int64_t i = 0; i < (srcRank - numTiles); i++) + ArrayRef innerDimPos(packOp.getInnerDimsPos()); + for (int64_t i = 0; i < srcRank; i++) { + // As we assume the trailing dimensions of the inner dim position correspond + // to the trailing indices of the transpose permutation, we need to + // calculate the remaining indicies of the transpose permutation. This is + // done by adding the indices not contained in the inner dimension position. + // For example if we have a source tensor of dimensions [0, 1, 2, 3] + // and inner dim position of [3, 0], the remaining indices are [1, 2]. + // and the transpose will be [1, 2, 3, 0]. + if (llvm::is_contained(innerDimPos, i)) + continue; srcPermForTranspose.push_back(i); - - srcPermForTranspose.append(SmallVector(packOp.getInnerDimsPos())); + } + srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end()); LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n" << "perm: " << llvm::interleaved(srcPermForTranspose) diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir index 911b453f919c3..6d091406a639c 100644 --- a/mlir/test/Dialect/Linalg/decompose-pack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir @@ -229,3 +229,22 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] // CHECK: return %[[INSERT]] + +// ----- + +func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> { + %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32> + return %pack : tensor<1x1x1x4x1xf32> +} + +// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>) +// CHECK-SAME: permutation = [1, 2, 0] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32> +// CHECK: return %[[INSERT]] \ No newline at end of file From 861c6a28c58027a32e0ff406ddbb6da7b8d93085 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Mon, 2 Jun 2025 12:58:05 +0000 Subject: [PATCH 2/4] Update(1) [mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern --- .../Dialect/Linalg/Transforms/Transforms.cpp | 26 +++++++++---------- mlir/test/Dialect/Linalg/decompose-pack.mlir | 3 +-- .../test/Dialect/Linalg/decompose-unpack.mlir | 17 ++++++++++++ 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 8f6488dbfab3d..7fe3414609ccd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1178,13 +1178,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( int64_t destRank = packOp.getDestRank(); int64_t numTiles = destRank - srcRank; - if (!llvm::all_of(packOp.getInnerDimsPos(), - [&srcRank, &numTiles](int64_t dimPos) { - return dimPos >= (srcRank - numTiles - 1); - })) - return rewriter.notifyMatchFailure( - packOp, "Attempting to tile non-trailing source dims!"); - // 1. Extract the inner tile sizes. // Where possible, values are replaced with constant attributes (to match the // behaviour of `getPackOpSourceOrPaddedSource`). @@ -1205,15 +1198,22 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) // Assumptions made: - // 1. Inner dims position correspond to the trailing `numTiles` dims. + // 1. All outer dims are 1 - the corresponding transposition order doesn't + // matter, but requires all dim indices to be present. + // 2. Inner dims position can have non-adjacent trailing dimensions. Where, + // For example, a source tensor with indices [0, 1, 2] can have: + // * adjacent trailing dimensions of [1, 2], [2, 1] + // * non-adjacent trailing dimensions of [0, 2] or [2, 0] + // Trailing dimensions are defined in the case above as index [2]. + // And the indices [0] or [1] are not defined to be trailing. SmallVector srcPermForTranspose; ArrayRef innerDimPos(packOp.getInnerDimsPos()); for (int64_t i = 0; i < srcRank; i++) { - // As we assume the trailing dimensions of the inner dim position correspond - // to the trailing indices of the transpose permutation, we need to - // calculate the remaining indicies of the transpose permutation. This is - // done by adding the indices not contained in the inner dimension position. - // For example if we have a source tensor of dimensions [0, 1, 2, 3] + // We assume the `k` dimensions of the inner dim position correspond + // to the last `k` indices of the transpose permutation. This is + // done by adding the indices not contained in the inner dimension position + // in order from 0 to `n`. Where n is the rank of the source tensor. + // For example if we have a source tensor with indices [0, 1, 2, 3] // and inner dim position of [3, 0], the remaining indices are [1, 2]. // and the transpose will be [1, 2, 3, 0]. if (llvm::is_contained(innerDimPos, i)) diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir index 6d091406a639c..6239a82168f38 100644 --- a/mlir/test/Dialect/Linalg/decompose-pack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir @@ -236,7 +236,6 @@ func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %a %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32> return %pack : tensor<1x1x1x4x1xf32> } - // CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] @@ -247,4 +246,4 @@ func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %a // CHECK-SAME: permutation = [1, 2, 0] // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32> -// CHECK: return %[[INSERT]] \ No newline at end of file +// CHECK: return %[[INSERT]] diff --git a/mlir/test/Dialect/Linalg/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-unpack.mlir index d460c506d6e18..c6c99dca186d5 100644 --- a/mlir/test/Dialect/Linalg/decompose-unpack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-unpack.mlir @@ -169,3 +169,20 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor, %arg1: tens // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1] // CHECK: return %[[INSERT]] + +// ----- + +func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { + %0 = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32> + return %0 : tensor<1x1x4xf32> +} +// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SLICE]] : tensor<4x1xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32> +// CHECK: return %[[INSERT]] From c3b4e6147d4d687850f16b6221c9e4b3111c301a Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Thu, 12 Jun 2025 14:59:54 +0000 Subject: [PATCH 3/4] Update(2) [mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern --- .../Dialect/Linalg/Transforms/Transforms.cpp | 20 +++++------- mlir/test/Dialect/Linalg/decompose-pack.mlir | 31 +++++++++++++++++-- .../test/Dialect/Linalg/decompose-unpack.mlir | 21 +++++++++++-- 3 files changed, 55 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 7fe3414609ccd..a775699f99343 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1200,22 +1200,16 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // Assumptions made: // 1. All outer dims are 1 - the corresponding transposition order doesn't // matter, but requires all dim indices to be present. - // 2. Inner dims position can have non-adjacent trailing dimensions. Where, - // For example, a source tensor with indices [0, 1, 2] can have: - // * adjacent trailing dimensions of [1, 2], [2, 1] - // * non-adjacent trailing dimensions of [0, 2] or [2, 0] - // Trailing dimensions are defined in the case above as index [2]. - // And the indices [0] or [1] are not defined to be trailing. SmallVector srcPermForTranspose; ArrayRef innerDimPos(packOp.getInnerDimsPos()); for (int64_t i = 0; i < srcRank; i++) { - // We assume the `k` dimensions of the inner dim position correspond - // to the last `k` indices of the transpose permutation. This is - // done by adding the indices not contained in the inner dimension position - // in order from 0 to `n`. Where n is the rank of the source tensor. - // For example if we have a source tensor with indices [0, 1, 2, 3] - // and inner dim position of [3, 0], the remaining indices are [1, 2]. - // and the transpose will be [1, 2, 3, 0]. + // We assume the `k` dimensions of the inner dim position, where `k` is the + // rank of the inner tiling, correspond to the last `k` indices of the + // transpose permutation. This is done by adding the indices not contained + // in the inner dimension position in order from 0 to `n`. Where n is the + // rank of the source tensor. For example if we have a source tensor with + // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining + // indices are [1, 2]. and the transpose will be [1, 2, 3, 0]. if (llvm::is_contained(innerDimPos, i)) continue; srcPermForTranspose.push_back(i); diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir index 6239a82168f38..17e6c29754f9d 100644 --- a/mlir/test/Dialect/Linalg/decompose-pack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir @@ -232,11 +232,14 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x // ----- -func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> { +// The following example shows a pack operation that is defined with inner +// dimension positions that are not adjacent, i.e. `[2, 0]`. And the outer +// dimensions of the packed tensor are of unit values, i.e. `1x1x1`. +func.func @pack_with_non_adjacent_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> { %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32> return %pack : tensor<1x1x1x4x1xf32> } -// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner +// CHECK-LABEL: func.func @pack_with_non_adjacent_inner_dims_pos_and_unit_outer // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32> @@ -247,3 +250,27 @@ func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %a // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32> // CHECK: return %[[INSERT]] + +// ----- + +// The following example shows a pack operation where the inner dimension +// positions are specified as [2, 1] which are termed adjacent trailing +// dimensions as they contain the last dimension of the source tensor with a +// neighboring dimension. [1, 2] would also be considered trailing adjacent. +// And the outer dimensions of the packed tensor are all set to unit values +// of `1x1x1`. +func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> { + %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32> + return %pack : tensor<1x1x1x4x1xf32> +} +// CHECK-LABEL: func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>) +// CHECK-SAME: permutation = [0, 2, 1] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32> +// CHECK: return %[[INSERT]] diff --git a/mlir/test/Dialect/Linalg/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-unpack.mlir index c6c99dca186d5..02ea6b1048afb 100644 --- a/mlir/test/Dialect/Linalg/decompose-unpack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-unpack.mlir @@ -172,11 +172,11 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor, %arg1: tens // ----- -func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { +func.func @unpack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { %0 = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32> return %0 : tensor<1x1x4xf32> } -// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner +// CHECK-LABEL: func.func @unpack_with_unit_outer_dims_and_unit_inner // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32> @@ -186,3 +186,20 @@ func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32> // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0] // CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32> // CHECK: return %[[INSERT]] + +// ----- + +func.func @unpack_with_non_trailing_dimensions_in_inner_dims(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { + %pack = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32> + return %pack : tensor<1x1x4xf32> +} +// CHECK-LABEL: func.func @unpack_with_non_trailing_dimensions_in_inner_dims +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SLICE]] : tensor<4x1xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32> +// CHECK: return %[[INSERT]] From af4d38d6449466ed52062ca85db97957991f80c1 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Thu, 26 Jun 2025 07:58:43 +0200 Subject: [PATCH 4/4] Update(3) [mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern --- mlir/test/Dialect/Linalg/decompose-unpack.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/Linalg/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-unpack.mlir index 02ea6b1048afb..e173d557c770d 100644 --- a/mlir/test/Dialect/Linalg/decompose-unpack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-unpack.mlir @@ -172,11 +172,11 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor, %arg1: tens // ----- -func.func @unpack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { +func.func @unpack_with_non_adjacent_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { %0 = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32> return %0 : tensor<1x1x4xf32> } -// CHECK-LABEL: func.func @unpack_with_unit_outer_dims_and_unit_inner +// CHECK-LABEL: func.func @unpack_with_non_adjacent_inner_dims_pos_and_unit_outer // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32>