From 3642d259c3ece69cbc41ab74af863d6b4b221839 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Tue, 15 Jul 2025 16:50:31 -0700 Subject: [PATCH 1/8] [mlir][linalg] Improve linalg.pack consumer fusion. If a dimension is not tiled, it is always valid to to fuse the pack op even if it has padding semantics. Because it always generates a full slice along the dimension. Signed-off-by: hanhanW --- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 41 +-- .../tile-and-fuse-consumer.mlir | 278 ++++++++++-------- 2 files changed, 184 insertions(+), 135 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 513cecef29b61..fb9ba4ccb14af 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -887,26 +887,13 @@ struct PackOpTiling ArrayRef offsets(allOffsets[0]); ArrayRef sizes(allSizes[0]); - auto packOp = cast(op); - // It is not trivial to infer dest tile from source tile if `packOp` has - // padding semantic. - if (packOp.getPaddingValue()) - return failure(); - Location loc = packOp.getLoc(); - SmallVector outerDimOffsets, outerDimSizes; DenseMap dimAndTileMapping = packOp.getDimAndTileMapping(); for (auto dim : llvm::seq(packOp.getSourceRank())) { if (dimAndTileMapping.count(dim)) { - FailureOr cstSize = - ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, sizes[dim], - /*stopCondition=*/nullptr, /*closedUB=*/true); - std::optional cstInnerSize = - getConstantIntValue(dimAndTileMapping[dim]); // Currently fusing `packOp` as consumer only expects perfect tiling // scenario because even if without padding semantic, the `packOp` may // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, @@ -916,12 +903,25 @@ struct PackOpTiling // (0,0)~(0,4) at first row. // 2. the second slice is extracted from (5) to (9) and SHOULD BE // respectively inserted into two rows with different length, including - // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate - // them, thus adding below constraint to bypass them temporarily. In - // another word, we can only support tiling with consumer if the tile - // size for the producer is a multiple of the inner tile size for the - // packed dimensions at this moment. - if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) { + // first row: (0,5) and second row (1,0)~(1,3). + // It is hard to coordinate them, thus adding below constraint to bypass + // them temporarily. In another word, we can only support tiling with + // consumer if the tile size for the producer is either a multiple of + // the inner tile size for the packed dimensions or the dimension is not + // tiled at this moment. + FailureOr cstTileSize = + ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType::UB, sizes[dim], + /*stopCondition=*/nullptr, /*closedUB=*/true); + std::optional cstInnerSize = + getConstantIntValue(dimAndTileMapping[dim]); + int64_t dimSize = packOp.getSourceType().getDimSize(dim); + // TODO: It could be untiled if the `dimSize` is dynamic. It is a hard + // check to determine if a dimension is tiled or not. + bool isTiled = failed(cstTileSize) || ShapedType::isDynamic(dimSize) || + cstTileSize.value() != dimSize; + if (isTiled && (failed(cstTileSize) || !cstInnerSize || + *cstTileSize % *cstInnerSize != 0)) { return failure(); } @@ -988,7 +988,8 @@ struct PackOpTiling loc, packOp.getDest(), outputOffsets, outputSizes, strides); tiledOperands.push_back(outSlice); - assert(!packOp.getPaddingValue() && "Expect no padding semantic"); + if (auto val = packOp.getPaddingValue()) + tiledOperands.push_back(val); for (auto tile : packOp.getInnerTiles()) tiledOperands.push_back(tile); diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index d09373bdb3f14..da3592547e125 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -193,33 +193,33 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { - %c4 = arith.constant 4 : index - %c64 = arith.constant 64 : index - %c0 = arith.constant 0 : index - %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) { - %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> - %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> - %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> - scf.forall.in_parallel { - tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> - tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> - } + func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> } - %1 = tensor.empty() : tensor<64x64xf32> - %2 = tensor.empty() : tensor<64x64xf32> - %3 = tensor.empty() : tensor<64x64xf32> - %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32): - %6 = arith.mulf %in, %in_0 : f32 - %7 = arith.subf %out, %6 : f32 - %8 = arith.addf %out_1, %in : f32 - linalg.yield %7, %8 : f32, f32 - } -> (tensor<64x64xf32>, tensor<64x64xf32>) - %5 = tensor.empty() : tensor<2048xf32> - %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32> - return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32> } + %1 = tensor.empty() : tensor<64x64xf32> + %2 = tensor.empty() : tensor<64x64xf32> + %3 = tensor.empty() : tensor<64x64xf32> + %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32): + %6 = arith.mulf %in, %in_0 : f32 + %7 = arith.subf %out, %6 : f32 + %8 = arith.addf %out_1, %in : f32 + linalg.yield %7, %8 : f32, f32 + } -> (tensor<64x64xf32>, tensor<64x64xf32>) + %5 = tensor.empty() : tensor<2048xf32> + %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32> + return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32> + } } module attributes {transform.with_named_sequence} { @@ -269,38 +269,38 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> { - %c4 = arith.constant 4 : index - %c64 = arith.constant 64 : index - %c0 = arith.constant 0 : index - %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { - %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> - %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { - ^bb0(%in: f32, %in_16: f32, %out: f32): - %13 = arith.mulf %in, %in_16 : f32 - %14 = arith.addf %out, %13 : f32 - linalg.yield %14 : f32 - } -> tensor<32x32xf32> - scf.forall.in_parallel { - tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> - } - } - %output = tensor.empty() : tensor<2048xf32> - %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32> - return %unpack : tensor<2048xf32> + func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } } + %output = tensor.empty() : tensor<2048xf32> + %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32> + return %unpack : tensor<2048xf32> + } } - + module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop = transform.structured.match ops{["scf.forall"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) - : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } } // CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> // CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)> @@ -332,38 +332,38 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> { - %c4 = arith.constant 4 : index - %c64 = arith.constant 64 : index - %c0 = arith.constant 0 : index - %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { - %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> - %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { - ^bb0(%in: f32, %in_16: f32, %out: f32): - %13 = arith.mulf %in, %in_16 : f32 - %14 = arith.addf %out, %13 : f32 - linalg.yield %14 : f32 - } -> tensor<32x32xf32> - scf.forall.in_parallel { - tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> - } - } - %output = tensor.empty() : tensor<2047xf32> - %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32> - return %unpack : tensor<2047xf32> + func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } } + %output = tensor.empty() : tensor<2047xf32> + %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32> + return %unpack : tensor<2047xf32> + } } - + module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop = transform.structured.match ops{["scf.forall"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) - : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } } // CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> // CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)> @@ -395,38 +395,38 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> { - %c4 = arith.constant 4 : index - %c64 = arith.constant 64 : index - %c0 = arith.constant 0 : index - %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { - %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> - %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { - ^bb0(%in: f32, %in_16: f32, %out: f32): - %13 = arith.mulf %in, %in_16 : f32 - %14 = arith.addf %out, %13 : f32 - linalg.yield %14 : f32 - } -> tensor<32x32xf32> - scf.forall.in_parallel { - tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> - } - } - %output = tensor.empty() : tensor<4x32x16xf32> - %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32> - return %pack : tensor<4x32x16xf32> + func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } } + %output = tensor.empty() : tensor<4x32x16xf32> + %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32> + return %pack : tensor<4x32x16xf32> + } } - + module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop = transform.structured.match ops{["scf.forall"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) - : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } } // CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_into_scf_forall( @@ -451,6 +451,54 @@ module attributes {transform.with_named_sequence} { // ----- +func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> { + %0 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg0[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32> + %extracted_slice_0 = tensor.extract_slice %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32> + %2 = linalg.exp ins(%extracted_slice : tensor<64x32xf32>) outs(%extracted_slice_0 : tensor<64x32xf32>) -> tensor<64x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> into tensor<64x32xf32> + } + } + %1 = tensor.empty() : tensor<23x32x3x16xf32> + %cst = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32> + return %pack : tensor<23x32x3x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x32x3x16xf32> +// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) in (2) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 32] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1] +// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1] + +// ----- + module { func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { %c0 = arith.constant 0 : index @@ -489,7 +537,7 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> // CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> // CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) // CHECK-SAME: { // CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] // CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1] @@ -645,7 +693,7 @@ func.func @multi_slice_fusion1(%arg0 : tensor, %arg1 : tensor, % scf.forall.in_parallel { tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor into tensor tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor into tensor - } + } } %empty = tensor.empty(%dim0) : tensor %result = linalg.generic { @@ -719,7 +767,7 @@ func.func @multi_slice_fusion2(%arg0 : tensor, %arg1 : tensor, % scf.forall.in_parallel { tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor into tensor tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor into tensor - } + } } %empty = tensor.empty(%dim0) : tensor %result = linalg.generic { From 061d4a2336d958d3bb83fd0ea64d7ce20b9cbc61 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Thu, 17 Jul 2025 11:15:31 -0700 Subject: [PATCH 2/8] Restrict the fusion condition. Signed-off-by: hanhanW --- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 27 ++++++--- .../tile-and-fuse-consumer.mlir | 59 ++++++++++++++----- 2 files changed, 63 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index fb9ba4ccb14af..f609c65818e43 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "llvm/Support/Debug.h" @@ -915,13 +916,25 @@ struct PackOpTiling /*stopCondition=*/nullptr, /*closedUB=*/true); std::optional cstInnerSize = getConstantIntValue(dimAndTileMapping[dim]); - int64_t dimSize = packOp.getSourceType().getDimSize(dim); - // TODO: It could be untiled if the `dimSize` is dynamic. It is a hard - // check to determine if a dimension is tiled or not. - bool isTiled = failed(cstTileSize) || ShapedType::isDynamic(dimSize) || - cstTileSize.value() != dimSize; - if (isTiled && (failed(cstTileSize) || !cstInnerSize || - *cstTileSize % *cstInnerSize != 0)) { + // If a dimension is not tiled, it is always valid to fuse the pack op, + // even if the op has padding semantics. Because it always generates a + // full slice along the dimension. + // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a + // hard check to determine if a dimension is tiled or not. + int64_t srcDimSize = packOp.getSourceType().getDimSize(dim); + bool isTiled = failed(cstTileSize) || + ShapedType::isDynamic(srcDimSize) || + cstTileSize.value() != srcDimSize; + int64_t destDimSize = packOp.getDestType().getDimSize(dim); + bool needPadding = ShapedType::isDynamic(destDimSize) || + !cstInnerSize || + destDimSize * cstInnerSize.value() != srcDimSize; + // Prioritize the case that the op already says that it does not need + // padding. + if (!packOp.getPaddingValue()) { + needPadding = false; + } + if (isTiled && needPadding) { return failure(); } diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index da3592547e125..3d32ddd9bed84 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -451,19 +451,19 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> { - %0 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %extracted_slice = tensor.extract_slice %arg0[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32> - %extracted_slice_0 = tensor.extract_slice %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32> - %2 = linalg.exp ins(%extracted_slice : tensor<64x32xf32>) outs(%extracted_slice_0 : tensor<64x32xf32>) -> tensor<64x32xf32> +func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> { + %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { + %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> scf.forall.in_parallel { - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> into tensor<64x32xf32> + tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> } } - %1 = tensor.empty() : tensor<23x32x3x16xf32> + %1 = tensor.empty() : tensor<23x2x3x16xf32> %cst = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32> - return %pack : tensor<23x32x3x16xf32> + %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x2x3x16xf32> + return %pack : tensor<23x2x3x16xf32> } module attributes {transform.with_named_sequence} { @@ -478,24 +478,51 @@ module attributes {transform.with_named_sequence} { // CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x32x3x16xf32> +// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x2x3x16xf32> // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) in (2) +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) -// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 32] [1, 1] -// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1] +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] // CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) -// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1] +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] // CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] // CHECK-SAME: into %[[TILED_PACK_DEST]] // CHECK: scf.forall.in_parallel { -// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] + +// ----- + +func.func @nofuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> { + %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { + %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + scf.forall.in_parallel { + // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> + } + } + %1 = tensor.empty() : tensor<23x32x3x16xf32> + %cst = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32> + return %pack : tensor<23x32x3x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // ----- From 864a9a55c6fafea9ff41f10e9ed757bfce0409cc Mon Sep 17 00:00:00 2001 From: hanhanW Date: Thu, 17 Jul 2025 11:16:35 -0700 Subject: [PATCH 3/8] Fix IR bug Signed-off-by: hanhanW --- .../Interfaces/TilingInterface/tile-and-fuse-consumer.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index 3d32ddd9bed84..fc64733d7a887 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -399,7 +399,7 @@ module { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index - %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { ^bb0(%in: f32, %in_16: f32, %out: f32): @@ -434,7 +434,7 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) // CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32> -// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 1) // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) // CHECK-SAME: { // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] From 583639db7bacdc197623df0e20185b865ffa095d Mon Sep 17 00:00:00 2001 From: hanhanW Date: Thu, 17 Jul 2025 12:24:57 -0700 Subject: [PATCH 4/8] Fix the fusion logic and add more lit tests Signed-off-by: hanhanW --- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 59 ++++++---- .../tile-and-fuse-consumer.mlir | 107 ++++++++++++++++-- 2 files changed, 137 insertions(+), 29 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index f609c65818e43..0513fbfe28148 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -895,48 +895,63 @@ struct PackOpTiling packOp.getDimAndTileMapping(); for (auto dim : llvm::seq(packOp.getSourceRank())) { if (dimAndTileMapping.count(dim)) { - // Currently fusing `packOp` as consumer only expects perfect tiling - // scenario because even if without padding semantic, the `packOp` may - // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, - // where the `tileSize` from operand of `packOp` is 5, which is not - // exactly divided by `innerTile`(=6) of `packOp`. As the result: - // 1. the first slice is extracted from (0) to (4) and inserted into - // (0,0)~(0,4) at first row. - // 2. the second slice is extracted from (5) to (9) and SHOULD BE - // respectively inserted into two rows with different length, including - // first row: (0,5) and second row (1,0)~(1,3). - // It is hard to coordinate them, thus adding below constraint to bypass - // them temporarily. In another word, we can only support tiling with - // consumer if the tile size for the producer is either a multiple of - // the inner tile size for the packed dimensions or the dimension is not - // tiled at this moment. FailureOr cstTileSize = ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType::UB, sizes[dim], /*stopCondition=*/nullptr, /*closedUB=*/true); std::optional cstInnerSize = getConstantIntValue(dimAndTileMapping[dim]); + // If a dimension is not tiled, it is always valid to fuse the pack op, // even if the op has padding semantics. Because it always generates a // full slice along the dimension. // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a // hard check to determine if a dimension is tiled or not. int64_t srcDimSize = packOp.getSourceType().getDimSize(dim); + int64_t destDimSize = packOp.getDestType().getDimSize(dim); bool isTiled = failed(cstTileSize) || ShapedType::isDynamic(srcDimSize) || cstTileSize.value() != srcDimSize; - int64_t destDimSize = packOp.getDestType().getDimSize(dim); - bool needPadding = ShapedType::isDynamic(destDimSize) || + if (!isTiled) { + outerDimOffsets.push_back(offsets[dim]); + if (ShapedType::isStatic(destDimSize)) { + outerDimSizes.push_back(b.getIndexAttr(destDimSize)); + } else { + outerDimSizes.push_back( + b.createOrFold(loc, packOp.getDest(), dim)); + } + continue; + } + + // If the dimension needs padding, it is not supported because there are + // iterations that only write padding values to the whole tile. The + // consumer fusion is driven by the source, so it is not possible to map + // an empty slice to the tile. + bool needExtraPadding = ShapedType::isDynamic(destDimSize) || !cstInnerSize || destDimSize * cstInnerSize.value() != srcDimSize; // Prioritize the case that the op already says that it does not need // padding. - if (!packOp.getPaddingValue()) { - needPadding = false; - } - if (isTiled && needPadding) { + if (!packOp.getPaddingValue()) + needExtraPadding = false; + if (needExtraPadding) + return failure(); + + // Currently fusing `packOp` as consumer only expects perfect tiling + // scenario because even if without padding semantic, the `packOp` may + // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, + // where the `tileSize` from operand of `packOp` is 5, which is not + // exactly divided by `innerTile`(=6) of `packOp`. As the result: + // 1. the first slice is extracted from (0) to (4) and inserted into + // (0,0)~(0,4) at first row. + // 2. the second slice is extracted from (5) to (9) and SHOULD BE + // respectively inserted into two rows with different length, including + // first row: (0,5) and second row (1,0)~(1,3). + // It is hard to coordinate them, thus adding below constraint to bypass + // them temporarily. + if ((failed(cstTileSize) || !cstInnerSize || + *cstTileSize % *cstInnerSize != 0)) return failure(); - } using AV = affine::AffineValueExpr; affine::AffineBuilder ab(b, loc); diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index fc64733d7a887..daa8341ca5a28 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -395,7 +395,7 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> { + func.func @fuse_perfect_tiling_pack_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index @@ -429,7 +429,7 @@ module attributes {transform.with_named_sequence} { } } // CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> -// CHECK: func.func @fuse_pack_consumer_into_scf_forall( +// CHECK: func.func @fuse_perfect_tiling_pack_consumer( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) @@ -451,7 +451,10 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> { +// It is valid to fuse the pack op with padding semantics if the dimension does +// not need padding. + +func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> { %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> @@ -475,7 +478,7 @@ module attributes {transform.with_named_sequence} { } } // CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> -// CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall( +// CHECK: func.func @fuse_pack_consumer_with_padding_semantics( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x2x3x16xf32> @@ -488,18 +491,72 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] // CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) -// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1] // CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] // CHECK-SAME: into %[[TILED_PACK_DEST]] // CHECK: scf.forall.in_parallel { // CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1] // ----- -func.func @nofuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> { +// It is valid to fuse the pack if the dimension is not tiled even when it needs +// extra padding. + +func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> { + %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { + %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> + } + } + %1 = tensor.empty() : tensor<33x2x3x16xf32> + %cst = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32> + return %pack : tensor<33x2x3x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32> +// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1] +// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1] + +// ----- + +// If the dimension is tiled and it needs extra padding, do not fuse the pack +// op. + +func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> { %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> @@ -526,6 +583,42 @@ module attributes {transform.with_named_sequence} { // ----- +// Imperfect tiling is not supported in pack op consumer fusion. + +#map = affine_map<(d0) -> (d0 * 5)> +#map1 = affine_map<(d0) -> (d0)> +func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x6xf32> { + %0 = tensor.empty() : tensor<30xf32> + %1 = scf.forall (%arg1) in (6) shared_outs(%arg2 = %0) -> (tensor<30xf32>) { + %3 = affine.apply #map(%arg1) + %extracted_slice = tensor.extract_slice %arg0[%3] [5] [1] : tensor<30xf32> to tensor<5xf32> + %extracted_slice_0 = tensor.extract_slice %arg2[%3] [5] [1] : tensor<30xf32> to tensor<5xf32> + %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<5xf32>) outs(%extracted_slice_0 : tensor<5xf32>) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %in, %in : f32 + linalg.yield %5 : f32 + } -> tensor<5xf32> + scf.forall.in_parallel { + // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32> + } + } + %2 = tensor.empty() : tensor<5x6xf32> + %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32> + return %pack : tensor<5x6xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + module { func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { %c0 = arith.constant 0 : index From a5305bdde964b336273156b5e5b501618847af9d Mon Sep 17 00:00:00 2001 From: hanhanW Date: Thu, 17 Jul 2025 12:36:01 -0700 Subject: [PATCH 5/8] Recover the comment and add one more test. Signed-off-by: hanhanW --- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 8 +-- .../tile-and-fuse-consumer.mlir | 49 +++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 0513fbfe28148..bc3e71d3b9b6f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -946,9 +946,11 @@ struct PackOpTiling // (0,0)~(0,4) at first row. // 2. the second slice is extracted from (5) to (9) and SHOULD BE // respectively inserted into two rows with different length, including - // first row: (0,5) and second row (1,0)~(1,3). - // It is hard to coordinate them, thus adding below constraint to bypass - // them temporarily. + // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate + // them, thus adding below constraint to bypass them temporarily. In + // another word, we can only support tiling with consumer if the tile + // size for the producer is a multiple of the inner tile size for the + // packed dimensions at this moment. if ((failed(cstTileSize) || !cstInnerSize || *cstTileSize % *cstInnerSize != 0)) return failure(); diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index daa8341ca5a28..ef9b9454c946e 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -451,6 +451,55 @@ module attributes {transform.with_named_sequence} { // ----- +// It is valid to fuse the pack op in perfect tiling scenario when the dimension +// is dynamic and padding is not needed. + +func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, %arg1: tensor<64x?xf32>, %1: tensor<64x?x16xf32>) -> tensor<64x?x16xf32> { + %c1 = arith.constant 1 : index + %d1 = tensor.dim %arg0, %c1 : tensor<64x?xf32> + %0 = scf.forall (%arg2) = (0) to (%d1) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x?xf32>) { + %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32> + %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32> + %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x?xf32> + } + } + %pack = linalg.pack %0 inner_dims_pos = [1] inner_tiles = [16] into %1 : tensor<64x?xf32> -> tensor<64x?x16xf32> + return %pack : tensor<64x?x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (%{{.+}}) step (16) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] + +// ----- + // It is valid to fuse the pack op with padding semantics if the dimension does // not need padding. From 183b79508ece80deccda397c4e40d46733019d3b Mon Sep 17 00:00:00 2001 From: hanhanW Date: Thu, 17 Jul 2025 13:40:06 -0700 Subject: [PATCH 6/8] format Signed-off-by: hanhanW --- mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index bc3e71d3b9b6f..5a10883a6043c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -927,9 +927,9 @@ struct PackOpTiling // iterations that only write padding values to the whole tile. The // consumer fusion is driven by the source, so it is not possible to map // an empty slice to the tile. - bool needExtraPadding = ShapedType::isDynamic(destDimSize) || - !cstInnerSize || - destDimSize * cstInnerSize.value() != srcDimSize; + bool needExtraPadding = + ShapedType::isDynamic(destDimSize) || !cstInnerSize || + destDimSize * cstInnerSize.value() != srcDimSize; // Prioritize the case that the op already says that it does not need // padding. if (!packOp.getPaddingValue()) From 9dcfb2fc3b75db0a3b80b27eecbfe0eed73ea0c4 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Thu, 17 Jul 2025 14:32:46 -0700 Subject: [PATCH 7/8] Fix comments Signed-off-by: hanhanW --- .../Interfaces/TilingInterface/tile-and-fuse-consumer.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index ef9b9454c946e..fa188b47fc031 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -500,8 +500,8 @@ module attributes {transform.with_named_sequence} { // ----- -// It is valid to fuse the pack op with padding semantics if the dimension does -// not need padding. +// It is valid to fuse the pack op with padding semantics if the tiled +// dimensions do not need padding. func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> { %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { From 99df9ef83bf812ff95f28c89da3466dc56f43a14 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Thu, 17 Jul 2025 15:40:57 -0700 Subject: [PATCH 8/8] 23 -> 22 Signed-off-by: hanhanW --- .../TilingInterface/tile-and-fuse-consumer.mlir | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index fa188b47fc031..7b0a8494a8acb 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -503,7 +503,7 @@ module attributes {transform.with_named_sequence} { // It is valid to fuse the pack op with padding semantics if the tiled // dimensions do not need padding. -func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> { +func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> { %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> @@ -512,10 +512,10 @@ func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, % tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> } } - %1 = tensor.empty() : tensor<23x2x3x16xf32> + %1 = tensor.empty() : tensor<22x2x3x16xf32> %cst = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x2x3x16xf32> - return %pack : tensor<23x2x3x16xf32> + %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<22x2x3x16xf32> + return %pack : tensor<22x2x3x16xf32> } module attributes {transform.with_named_sequence} { @@ -530,7 +530,7 @@ module attributes {transform.with_named_sequence} { // CHECK: func.func @fuse_pack_consumer_with_padding_semantics( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x2x3x16xf32> +// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32> // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) @@ -540,14 +540,14 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] // CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) -// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1] +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] // CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] // CHECK-SAME: into %[[TILED_PACK_DEST]] // CHECK: scf.forall.in_parallel { // CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] // -----