
Commit 3ea6da5

[mlir][linalg] Allow pack consumer fusion if the tile size is greater than dimension size. (#149438)
This happens only when the tile size is greater than or equal to the dimension size; in that case the generated slice is a full slice, so the pack op is fusible. Such IR can be produced during the TileAndFuse process, and it is hard to fix in that driver, so we enable the naive fusion for this case.

Signed-off-by: hanhanW <[email protected]>
1 parent 87c2adb · commit 3ea6da5
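As a quick illustration of the relaxed rule, here is a minimal standalone sketch (not the actual MLIR code; `isDimTiled` is a hypothetical helper that mirrors the predicate in PackOpTiling): a dimension counts as tiled only when the constant tile size is strictly smaller than the source dimension size, so a tile that covers the whole dimension is a full slice and safe to fuse.

#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the check in PackOpTiling: a dimension is
// tiled only if the tile size is strictly smaller than the dim size.
static bool isDimTiled(int64_t tileSize, int64_t srcDimSize) {
  return tileSize < srcDimSize;
}

int main() {
  // The old predicate (tileSize != srcDimSize) treated 16 vs. 4 as tiled
  // and blocked fusion; under the new one it is a full slice.
  assert(!isDimTiled(/*tileSize=*/16, /*srcDimSize=*/4)); // full slice, fusible
  assert(!isDimTiled(/*tileSize=*/4, /*srcDimSize=*/4));  // exact fit, fusible
  assert(isDimTiled(/*tileSize=*/2, /*srcDimSize=*/4));   // genuinely tiled
  return 0;
}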

2 files changed: 54 additions & 2 deletions

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 4 additions & 2 deletions
@@ -911,14 +911,16 @@ struct PackOpTiling

       // If a dimension is not tiled, it is always valid to fuse the pack op,
       // even if the op has padding semantics. Because it always generates a
-      // full slice along the dimension.
+      // full slice along the dimension. The tile sizes are for unpacked
+      // domain, i.e., `srcDimSize`, so `tileSize < srcDimSize` means that the
+      // dimension is tiled.
       // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
       // hard check to determine if a dimension is tiled or not.
       int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
       int64_t destDimSize = outerShapeWithoutTranspose[dim];
       bool isTiled = failed(cstTileSize) ||
                      ShapedType::isDynamic(srcDimSize) ||
-                     cstTileSize.value() != srcDimSize;
+                     cstTileSize.value() < srcDimSize;
       if (!isTiled) {
         outerDimOffsets.push_back(offsets[dim]);
         if (ShapedType::isStatic(destDimSize)) {
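The new test below is named "single iteration" because a loop from 0 to 4 with step 16 runs exactly once. A minimal arithmetic sketch of that trip count (the `ceilDiv` helper is local to the example, not an MLIR API):

#include <cstdint>
#include <cstdio>

static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  // Trip count of the scf.forall in the test: ceil(4 / 16) = 1, so the
  // single tile spans the whole dimension, i.e., a full slice.
  std::printf("trip count = %lld\n",
              static_cast<long long>(ceilDiv(/*dimSize=*/4, /*tileSize=*/16)));
  return 0;
}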

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 50 additions & 0 deletions
@@ -451,6 +451,56 @@ module attributes {transform.with_named_sequence} {

// -----

+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> {
+  %0 = tensor.empty() : tensor<1x4x16x1xf32>
+  %1 = tensor.empty() : tensor<4x4xf32>
+  %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+    %3 = affine.min #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+    }
+  }
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32>
+  return %pack : tensor<1x4x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)>
+// CHECK: func.func @fuse_pack_consumer_if_single_iteration(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32>
+// CHECK-DAG: %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+// CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+
+// -----

func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
  %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
