diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 615d1f66414b9..a775699f99343 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1178,13 +1178,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( int64_t destRank = packOp.getDestRank(); int64_t numTiles = destRank - srcRank; - if (!llvm::all_of(packOp.getInnerDimsPos(), - [&srcRank, &numTiles](int64_t dimPos) { - return dimPos >= (srcRank - numTiles - 1); - })) - return rewriter.notifyMatchFailure( - packOp, "Attempting to tile non-trailing source dims!"); - // 1. Extract the inner tile sizes. // Where possible, values are replaced with constant attributes (to match the // behaviour of `getPackOpSourceOrPaddedSource`). @@ -1204,16 +1197,24 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // %init = tensor.empty() // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) - // Two assumptions are made: - // 1. All outer dims are 1 - the corresponding transposition doesn't matter. - // 2. Inner dims position correspond to the trailing `numTiles` dims. - SmallVector tilesPermNormalized = - getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos()); + // Assumptions made: + // 1. All outer dims are 1 - the corresponding transposition order doesn't + // matter, but requires all dim indices to be present. SmallVector srcPermForTranspose; - for (int64_t i = 0; i < (srcRank - numTiles); i++) + ArrayRef innerDimPos(packOp.getInnerDimsPos()); + for (int64_t i = 0; i < srcRank; i++) { + // We assume the `k` dimensions of the inner dim position, where `k` is the + // rank of the inner tiling, correspond to the last `k` indices of the + // transpose permutation. This is done by adding the indices not contained + // in the inner dimension position in order from 0 to `n`. Where n is the + // rank of the source tensor. For example if we have a source tensor with + // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining + // indices are [1, 2]. and the transpose will be [1, 2, 3, 0]. + if (llvm::is_contained(innerDimPos, i)) + continue; srcPermForTranspose.push_back(i); - - srcPermForTranspose.append(SmallVector(packOp.getInnerDimsPos())); + } + srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end()); LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n" << "perm: " << llvm::interleaved(srcPermForTranspose) diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir index 911b453f919c3..17e6c29754f9d 100644 --- a/mlir/test/Dialect/Linalg/decompose-pack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir @@ -229,3 +229,48 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] // CHECK: return %[[INSERT]] + +// ----- + +// The following example shows a pack operation that is defined with inner +// dimension positions that are not adjacent, i.e. `[2, 0]`. And the outer +// dimensions of the packed tensor are of unit values, i.e. `1x1x1`. +func.func @pack_with_non_adjacent_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> { + %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32> + return %pack : tensor<1x1x1x4x1xf32> +} +// CHECK-LABEL: func.func @pack_with_non_adjacent_inner_dims_pos_and_unit_outer +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>) +// CHECK-SAME: permutation = [1, 2, 0] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32> +// CHECK: return %[[INSERT]] + +// ----- + +// The following example shows a pack operation where the inner dimension +// positions are specified as [2, 1] which are termed adjacent trailing +// dimensions as they contain the last dimension of the source tensor with a +// neighboring dimension. [1, 2] would also be considered trailing adjacent. +// And the outer dimensions of the packed tensor are all set to unit values +// of `1x1x1`. +func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> { + %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32> + return %pack : tensor<1x1x1x4x1xf32> +} +// CHECK-LABEL: func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>) +// CHECK-SAME: permutation = [0, 2, 1] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32> +// CHECK: return %[[INSERT]] diff --git a/mlir/test/Dialect/Linalg/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-unpack.mlir index d460c506d6e18..e173d557c770d 100644 --- a/mlir/test/Dialect/Linalg/decompose-unpack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-unpack.mlir @@ -169,3 +169,37 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor, %arg1: tens // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1] // CHECK: return %[[INSERT]] + +// ----- + +func.func @unpack_with_non_adjacent_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { + %0 = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32> + return %0 : tensor<1x1x4xf32> +} +// CHECK-LABEL: func.func @unpack_with_non_adjacent_inner_dims_pos_and_unit_outer +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SLICE]] : tensor<4x1xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32> +// CHECK: return %[[INSERT]] + +// ----- + +func.func @unpack_with_non_trailing_dimensions_in_inner_dims(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { + %pack = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32> + return %pack : tensor<1x1x4xf32> +} +// CHECK-LABEL: func.func @unpack_with_non_trailing_dimensions_in_inner_dims +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SLICE]] : tensor<4x1xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32> +// CHECK: return %[[INSERT]]