diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3c160d55a38e7..f31371ec6a054 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1519,7 +1519,7 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
 /// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
 ///   tensor::InsertSliceOp ops.
 ///
-/// Required that all the outer dims of the input tensor::PackOp are 1.
+/// Requires that all the outer dims of the input tensor::PackOp are 1.
 ///
 /// Before:
 /// ```
@@ -1555,9 +1555,33 @@ struct DecomposeOuterUnitDimsPackOpPattern
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced extract_slice op
-/// + transpose op + insert_slice op, where the tensor::UnPackOp has outer dims
-/// being all 1s.
+/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced
+///   * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
+///
+/// Requires that all the outer dims of the input tensor::UnPackOp are 1.
+///
+/// Before:
+/// ```
+///   %unpacked = tensor.unpack %input
+///     inner_dims_pos = [1, 0]
+///     inner_tiles = [2, 8]
+///     into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
+/// ```
+///
+/// After:
+/// ```
+///   // Rank-reduced extract to obtain the tile
+///   %slice = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1]
+///     : tensor<1x1x2x8xf32> to tensor<2x8xf32>
+///   // EmptyOp + TransposeOp
+///   %init = tensor.empty() : tensor<8x2xf32>
+///   %transposed = linalg.transpose
+///     ins(%slice : tensor<2x8xf32>)
+///     outs(%init : tensor<8x2xf32>) permutation = [1, 0]
+///   // Extract a slice matching the specified output size
+///   %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1]
+///     : tensor<8x2xf32> to tensor<5x1xf32>
+/// ```
 struct DecomposeOuterUnitDimsUnPackOpPattern
     : public OpRewritePattern<tensor::UnPackOp> {
   using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir
index 4f986606ef93a..1cc1484ed4095 100644
--- a/mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir
@@ -67,6 +67,9 @@ func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: te
 // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
 
+/// Same as example above, but the dynamic tile size is a compile-time constant
+/// that's folded away.
+
 func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
   %tile_dim_0 = arith.constant 8 : index
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
index 8b15873473a97..a720c655e4be5 100644
--- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
@@ -19,11 +19,11 @@ func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<
 
 // -----
 
-func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
+func.func @simple_unpack_static_tiles(%input: tensor<1x1x8x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
   %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<1x1x8x2xf32> -> tensor<5x1xf32>
   return %0 : tensor<5x1xf32>
 }
-// CHECK-LABEL: func.func @simple_unpack_and_extract_slice
+// CHECK-LABEL: func.func @simple_unpack_static_tiles
 // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
 // CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
@@ -33,6 +33,55 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output:
 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
 // CHECK: return %[[SLICE]]
 
+/// Same as example above, but with 1 dynamic tile size.
+
+func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
+  return %0 : tensor<5x1xf32>
+}
+// CHECK-LABEL: func.func @simple_unpack_dynamic_tile
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[TILE_DIM_1:[a-zA-Z0-9]+]]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_1]], 2] [1, 1, 1, 1]
+// CHECK-NOT: linalg.transpose
+// They have the same type, so the insert_slice op is folded
+// away.
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
+// CHECK: return %[[SLICE]]
+
+/// Same as example above, but with 1 dynamic tile size and a transpose.
+
+/// FIXME: This is currently broken:
+/// * 'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1
+
+//func.func @simple_unpack_dynamic_tile_transpose(%input: tensor<1x1x2x?xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
+// %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_0] into %output : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
+// return %0 : tensor<5x1xf32>
+//}
+
+/// Same as example above, but with 1 scalable tile size.
+
+func.func @simple_unpack_scalable_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %c8_vscale = arith.muli %vscale, %c8 : index
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
+  return %0 : tensor<5x1xf32>
+}
+// CHECK-LABEL: func.func @simple_unpack_scalable_tile
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[VS:.+]] = vector.vscale
+// CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1]
+// CHECK-NOT: linalg.transpose
+// They have the same type, so the insert_slice op is folded
+// away.
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
+// CHECK: return %[[SLICE]]
+
 // -----
 
 func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32>) -> tensor<32x8xf32>{