Skip to content

Commit 583639d

Browse files
committed
Fix the fusion logic and add more lit tests
Signed-off-by: hanhanW <[email protected]>
1 parent 864a9a5 commit 583639d

File tree

2 files changed

+137
-29
lines changed

2 files changed

+137
-29
lines changed

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -895,48 +895,63 @@ struct PackOpTiling
895895
packOp.getDimAndTileMapping();
896896
for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
897897
if (dimAndTileMapping.count(dim)) {
898-
// Currently fusing `packOp` as consumer only expects perfect tiling
899-
// scenario because even if without padding semantic, the `packOp` may
900-
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
901-
// where the `tileSize` from operand of `packOp` is 5, which is not
902-
// exactly divided by `innerTile`(=6) of `packOp`. As a result:
903-
// 1. the first slice is extracted from (0) to (4) and inserted into
904-
// (0,0)~(0,4) at first row.
905-
// 2. the second slice is extracted from (5) to (9) and SHOULD BE
906-
// respectively inserted into two rows with different length, including
907-
// first row: (0,5) and second row (1,0)~(1,3).
908-
// It is hard to coordinate them, thus adding below constraint to bypass
909-
// them temporarily. In other words, we can only support tiling with
910-
// consumer if the tile size for the producer is either a multiple of
911-
// the inner tile size for the packed dimensions or the dimension is not
912-
// tiled at this moment.
913898
FailureOr<int64_t> cstTileSize =
914899
ValueBoundsConstraintSet::computeConstantBound(
915900
presburger::BoundType::UB, sizes[dim],
916901
/*stopCondition=*/nullptr, /*closedUB=*/true);
917902
std::optional<int64_t> cstInnerSize =
918903
getConstantIntValue(dimAndTileMapping[dim]);
904+
919905
// If a dimension is not tiled, it is always valid to fuse the pack op,
920906
// even if the op has padding semantics. Because it always generates a
921907
// full slice along the dimension.
922908
// TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
923909
// hard check to determine if a dimension is tiled or not.
924910
int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
911+
int64_t destDimSize = packOp.getDestType().getDimSize(dim);
925912
bool isTiled = failed(cstTileSize) ||
926913
ShapedType::isDynamic(srcDimSize) ||
927914
cstTileSize.value() != srcDimSize;
928-
int64_t destDimSize = packOp.getDestType().getDimSize(dim);
929-
bool needPadding = ShapedType::isDynamic(destDimSize) ||
915+
if (!isTiled) {
916+
outerDimOffsets.push_back(offsets[dim]);
917+
if (ShapedType::isStatic(destDimSize)) {
918+
outerDimSizes.push_back(b.getIndexAttr(destDimSize));
919+
} else {
920+
outerDimSizes.push_back(
921+
b.createOrFold<tensor::DimOp>(loc, packOp.getDest(), dim));
922+
}
923+
continue;
924+
}
925+
926+
// If the dimension needs padding, it is not supported because there are
927+
// iterations that only write padding values to the whole tile. The
928+
// consumer fusion is driven by the source, so it is not possible to map
929+
// an empty slice to the tile.
930+
bool needExtraPadding = ShapedType::isDynamic(destDimSize) ||
930931
!cstInnerSize ||
931932
destDimSize * cstInnerSize.value() != srcDimSize;
932933
// Prioritize the case that the op already says that it does not need
933934
// padding.
934-
if (!packOp.getPaddingValue()) {
935-
needPadding = false;
936-
}
937-
if (isTiled && needPadding) {
935+
if (!packOp.getPaddingValue())
936+
needExtraPadding = false;
937+
if (needExtraPadding)
938+
return failure();
939+
940+
// Currently fusing `packOp` as consumer only expects perfect tiling
941+
// scenario because even if without padding semantic, the `packOp` may
942+
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
943+
// where the `tileSize` from operand of `packOp` is 5, which is not
944+
// exactly divided by `innerTile`(=6) of `packOp`. As a result:
945+
// 1. the first slice is extracted from (0) to (4) and inserted into
946+
// (0,0)~(0,4) at first row.
947+
// 2. the second slice is extracted from (5) to (9) and SHOULD BE
948+
// respectively inserted into two rows with different length, including
949+
// first row: (0,5) and second row (1,0)~(1,3).
950+
// It is hard to coordinate them, thus adding below constraint to bypass
951+
// them temporarily.
952+
if ((failed(cstTileSize) || !cstInnerSize ||
953+
*cstTileSize % *cstInnerSize != 0))
938954
return failure();
939-
}
940955

941956
using AV = affine::AffineValueExpr;
942957
affine::AffineBuilder ab(b, loc);

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 100 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ module attributes {transform.with_named_sequence} {
395395

396396
#map = affine_map<(d0, d1) -> (d0, d1)>
397397
module {
398-
func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
398+
func.func @fuse_perfect_tiling_pack_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
399399
%c4 = arith.constant 4 : index
400400
%c64 = arith.constant 64 : index
401401
%c0 = arith.constant 0 : index
@@ -429,7 +429,7 @@ module attributes {transform.with_named_sequence} {
429429
}
430430
}
431431
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
432-
// CHECK: func.func @fuse_pack_consumer_into_scf_forall(
432+
// CHECK: func.func @fuse_perfect_tiling_pack_consumer(
433433
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
434434
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
435435
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
@@ -451,7 +451,10 @@ module attributes {transform.with_named_sequence} {
451451

452452
// -----
453453

454-
func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> {
454+
// It is valid to fuse the pack op with padding semantics if the dimension does
455+
// not need padding.
456+
457+
func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> {
455458
%0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
456459
%src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
457460
%dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -475,7 +478,7 @@ module attributes {transform.with_named_sequence} {
475478
}
476479
}
477480
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
478-
// CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(
481+
// CHECK: func.func @fuse_pack_consumer_with_padding_semantics(
479482
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
480483
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
481484
// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x2x3x16xf32>
@@ -488,18 +491,72 @@ module attributes {transform.with_named_sequence} {
488491
// CHECK-SAME: ins(%[[ELEM_SRC]]
489492
// CHECK-SAME: outs(%[[ELEM_DEST]]
490493
// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
491-
// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
494+
// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1]
492495
// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
493496
// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
494497
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
495498
// CHECK-SAME: into %[[TILED_PACK_DEST]]
496499
// CHECK: scf.forall.in_parallel {
497500
// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
498-
// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
501+
// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1]
499502

500503
// -----
501504

502-
func.func @nofuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
505+
// It is valid to fuse the pack if the dimension is not tiled even when it needs
506+
// extra padding.
507+
508+
func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
509+
%0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
510+
%src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
511+
%dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
512+
%2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
513+
scf.forall.in_parallel {
514+
tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
515+
}
516+
}
517+
%1 = tensor.empty() : tensor<33x2x3x16xf32>
518+
%cst = arith.constant 0.000000e+00 : f32
519+
%pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
520+
return %pack : tensor<33x2x3x16xf32>
521+
}
522+
523+
module attributes {transform.with_named_sequence} {
524+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
525+
%0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
526+
%1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
527+
%consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
528+
transform.yield
529+
}
530+
}
531+
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
532+
// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
533+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
534+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
535+
// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
536+
// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
537+
// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
538+
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
539+
// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
540+
// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
541+
// CHECK: %[[ELEM:.*]] = linalg.exp
542+
// CHECK-SAME: ins(%[[ELEM_SRC]]
543+
// CHECK-SAME: outs(%[[ELEM_DEST]]
544+
// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
545+
// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
546+
// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
547+
// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
548+
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
549+
// CHECK-SAME: into %[[TILED_PACK_DEST]]
550+
// CHECK: scf.forall.in_parallel {
551+
// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
552+
// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
553+
554+
// -----
555+
556+
// If the dimension is tiled and it needs extra padding, do not fuse the pack
557+
// op.
558+
559+
func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
503560
%0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
504561
%src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
505562
%dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -526,6 +583,42 @@ module attributes {transform.with_named_sequence} {
526583

527584
// -----
528585

586+
// Imperfect tiling is not supported in pack op consumer fusion.
587+
588+
#map = affine_map<(d0) -> (d0 * 5)>
589+
#map1 = affine_map<(d0) -> (d0)>
590+
func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x6xf32> {
591+
%0 = tensor.empty() : tensor<30xf32>
592+
%1 = scf.forall (%arg1) in (6) shared_outs(%arg2 = %0) -> (tensor<30xf32>) {
593+
%3 = affine.apply #map(%arg1)
594+
%extracted_slice = tensor.extract_slice %arg0[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
595+
%extracted_slice_0 = tensor.extract_slice %arg2[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
596+
%4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<5xf32>) outs(%extracted_slice_0 : tensor<5xf32>) {
597+
^bb0(%in: f32, %out: f32):
598+
%5 = arith.addf %in, %in : f32
599+
linalg.yield %5 : f32
600+
} -> tensor<5xf32>
601+
scf.forall.in_parallel {
602+
// expected-error @below {{failed to fuse consumer of slice}}
603+
tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32>
604+
}
605+
}
606+
%2 = tensor.empty() : tensor<5x6xf32>
607+
%pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32>
608+
return %pack : tensor<5x6xf32>
609+
}
610+
611+
module attributes {transform.with_named_sequence} {
612+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
613+
%0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
614+
%1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
615+
%consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
616+
transform.yield
617+
}
618+
}
619+
620+
// -----
621+
529622
module {
530623
func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
531624
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)