
Commit 4d52e5e

[mlir][linalg] Support pack consumer fusion with padding semantics for perfect tiling.

If the op does not generate extra padding values (i.e., `destDimSize == ceilDiv(srcDimSize, innerTileSize)`), it is valid to fuse the pack consumer op.

Signed-off-by: hanhanW <[email protected]>

2 files changed: +40 -23 lines changed

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 12 additions & 7 deletions
@@ -932,13 +932,18 @@ struct PackOpTiling
        continue;
      }
 
-      // If the dimension needs padding, it is not supported because there are
-      // iterations that only write padding values to the whole tile. The
-      // consumer fusion is driven by the source, so it is not possible to map
-      // an empty slice to the tile.
-      bool needExtraPadding =
-          ShapedType::isDynamic(destDimSize) || !cstInnerSize ||
-          destDimSize * cstInnerSize.value() != srcDimSize;
+      // If the dimension needs extra padding, it is not supported because
+      // there are iterations that only write padding values to the whole
+      // tile. The consumer fusion is driven by the source, so it is not
+      // possible to map an empty slice to the tile. Extra padding is not a
+      // regular form, and the implementation is being conservative.
+      bool needExtraPadding = true;
+      if (!ShapedType::isDynamic(srcDimSize) &&
+          !ShapedType::isDynamic(destDimSize) && cstInnerSize) {
+        needExtraPadding =
+            destDimSize >
+            (srcDimSize + cstInnerSize.value() - 1) / cstInnerSize.value();
+      }
       // Prioritize the case that the op already says that it does not need
       // padding.
       if (!packOp.getPaddingValue())
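
For illustration, here is a minimal standalone sketch of the new legality check, using plain int64_t values in place of the MLIR ShapedType/OpFoldResult machinery. The helper name needsExtraPadding and the standalone program are illustrative only; the concrete sizes are taken from the updated test below (source dim 64, inner tile 3, destination dim 22).

// Minimal sketch of the fusion-legality check, assuming all sizes are static.
// needsExtraPadding mirrors the predicate in the hunk above; it is not the
// actual MLIR API.
#include <cassert>
#include <cstdint>

static bool needsExtraPadding(int64_t srcDimSize, int64_t destDimSize,
                              int64_t innerTileSize) {
  // Extra padding means the destination holds more outer tiles than
  // ceilDiv(srcDimSize, innerTileSize) requires, i.e. some tiles would be
  // written with padding values only.
  int64_t numTiles = (srcDimSize + innerTileSize - 1) / innerTileSize;
  return destDimSize > numTiles;
}

int main() {
  // Values from the updated test: dim 0 of size 64 packed with inner tile 3
  // into 22 outer tiles. ceilDiv(64, 3) == 22, so the last tile is partially
  // padded but no tile is pure padding, and fusion remains legal.
  assert(!needsExtraPadding(/*srcDimSize=*/64, /*destDimSize=*/22,
                            /*innerTileSize=*/3));
  // A destination with 23 outer tiles would contain a tile of pure padding,
  // which the check rejects.
  assert(needsExtraPadding(/*srcDimSize=*/64, /*destDimSize=*/23,
                           /*innerTileSize=*/3));
  return 0;
}

The actual implementation additionally treats any dynamic size or non-constant inner tile as needing extra padding, as shown in the hunk above.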

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 28 additions & 16 deletions
@@ -596,15 +596,16 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // It is valid to fuse the pack op with padding semantics if the tiled
-// dimensions do not need padding.
+// dimensions do not need extra padding and it is a perfect tiling case.
 
 func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> {
-  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
-    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
-    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
-    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+  %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+    %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2)
+    %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32>
+    %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32>
+    %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32>
     scf.forall.in_parallel {
-      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+      tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32>
     }
   }
   %1 = tensor.empty() : tensor<22x2x3x16xf32>
@@ -621,28 +622,39 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
 // CHECK: func.func @fuse_pack_consumer_with_padding_semantics(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32>
 // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
-// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
-// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16)
+// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]])
+// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]]
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
 // CHECK: %[[ELEM:.*]] = linalg.exp
 // CHECK-SAME: ins(%[[ELEM_SRC]]
 // CHECK-SAME: outs(%[[ELEM_DEST]]
-// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
-// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
+// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]])
+// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]])
+// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]]
+// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
 // CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
 // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
 // CHECK-SAME: into %[[TILED_PACK_DEST]]
 // CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]]
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT]]
+// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1]
 
 // -----
 
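For reference, the shapes in this test satisfy the perfect-tiling condition in both packed dimensions: ceilDiv(64, 3) = 22 and ceilDiv(32, 16) = 2, matching the destination type tensor<22x2x3x16xf32>. The last tile along dim 0 holds one source row plus two rows of padding values, but no tile consists of padding only, so needExtraPadding stays false and the pack op can be fused as a consumer even though it carries a padding_value.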
