From a4ed7d89ea1bfcfdfcf08e43304bd8ae8662424f Mon Sep 17 00:00:00 2001 From: hanhanW Date: Thu, 17 Jul 2025 19:03:25 -0700 Subject: [PATCH 1/3] [mlir][linalg] Allow pack consumer fusion if the tile size is greater than dimension size. This only happens when the tile size is greater than or equal to the dimension size. In this case, it is a full slice, so it is fusible. The IR can be generated during the TileAndFuse process. It is hard to fix in such a driver, so we enable the naive fusion for this case. Signed-off-by: hanhanW --- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 2 +- .../tile-and-fuse-consumer.mlir | 50 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index b059bcc025315..7581ccd22d8ec 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -918,7 +918,7 @@ struct PackOpTiling int64_t destDimSize = outerShapeWithoutTranspose[dim]; bool isTiled = failed(cstTileSize) || ShapedType::isDynamic(srcDimSize) || - cstTileSize.value() != srcDimSize; + cstTileSize.value() < srcDimSize; if (!isTiled) { outerDimOffsets.push_back(offsets[dim]); if (ShapedType::isStatic(destDimSize)) { diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index 20164d5dfd91a..cdbca7228ded3 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -451,6 +451,56 @@ module attributes {transform.with_named_sequence} { // ----- +#map = affine_map<(d0) -> (-d0 + 4, 16)> +func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> { + %0 = tensor.empty() : tensor<1x4x16x1xf32> + %1 = tensor.empty() : tensor<4x4xf32> + %2 = 
scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) { + %3 = affine.min #map(%arg1) + %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32> + %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32> + %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32> + } + } + %cst = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32> + return %pack : tensor<1x4x16x1xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)> +// CHECK: func.func @fuse_pack_consumer_if_single_iteration( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32> +// CHECK-DAG: %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]]) +// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]]) +// CHECK-DAG: %[[ELEM_SRC:.*]] 
= tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) +// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1] + +// ----- func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> { %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { From 4c992de40674753f64c6a6909b6616a3bb587e49 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Fri, 18 Jul 2025 09:44:14 -0700 Subject: [PATCH 2/3] add a comment Signed-off-by: hanhanW --- mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 7581ccd22d8ec..0816bcd331c9c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -911,7 +911,9 @@ struct PackOpTiling // If a dimension is not tiled, it is always valid to fuse the pack op, // even if the op has padding semantics. 
Because it always generates a - // full slice along the dimension. + // full slice along the dimension. The tile sizes are for unpacked + // domain, i.e., `srcDimSize`, so `tileSize < srcDimSize` means no + // tiling at all. // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a // hard check to determine if a dimension is tiled or not. int64_t srcDimSize = packOp.getSourceType().getDimSize(dim); From 6c51a22d24339e289891299664ff31065d48e65f Mon Sep 17 00:00:00 2001 From: hanhanW Date: Fri, 18 Jul 2025 09:57:41 -0700 Subject: [PATCH 3/3] correct the comment Signed-off-by: hanhanW --- mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 0816bcd331c9c..28d99b130963a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -912,8 +912,8 @@ struct PackOpTiling // If a dimension is not tiled, it is always valid to fuse the pack op, // even if the op has padding semantics. Because it always generates a // full slice along the dimension. The tile sizes are for unpacked - // domain, i.e., `srcDimSize`, so `tileSize < srcDimSize` means no - // tiling at all. + // domain, i.e., `srcDimSize`, so `tileSize < srcDimSize` means that the + // dimension is tiled. // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a // hard check to determine if a dimension is tiled or not. int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);