Skip to content

Commit 372b16c

Browse files
add pooling_ncw_max_fill_fuse example.
1 parent 21a8ff0 commit 372b16c

File tree

4 files changed

+52
-9
lines changed

4 files changed

+52
-9
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,11 +1759,13 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
17591759
}
17601760

17611761
// The extract_slice op is created in the innermost loop by default. Using
1762-
// hoistLoopInvariantSubsets improves the position of the extract_slice op
1763-
// within the loops, allowing the fuse Op to be created in the correct loop.
1764-
for (LoopLikeOpInterface loop : loops) {
1762+
// `moveLoopInvariantCode` and `hoistLoopInvariantSubsets` improves the
1763+
// position of the extract_slice op within the loops, allowing the fuse Op to
1764+
// be created in the correct loop.
1765+
for (LoopLikeOpInterface loop : loops)
1766+
(void)moveLoopInvariantCode(loop);
1767+
for (LoopLikeOpInterface loop : loops)
17651768
(void)hoistLoopInvariantSubsets(rewriter, loop);
1766-
}
17671769

17681770
// Since the loop gets potentially replaced during fusion, we need to track
17691771
// the mutation of replacement values. To do this, we attach a listener to

mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ module attributes {transform.with_named_sequence} {
2828
// CHECK: %[[INIT:.+]] = tensor.empty
2929
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
3030
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
31+
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
3132
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
3233
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
33-
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
3434
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
3535
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
3636
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
@@ -141,6 +141,7 @@ module attributes {transform.with_named_sequence} {
141141
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]])
142142
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
143143
// CHECK: %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D2]])
144+
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
144145
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
145146
// CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]])
146147
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
@@ -151,7 +152,6 @@ module attributes {transform.with_named_sequence} {
151152
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
152153
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
153154
// CHECK-SAME: outs(%[[FILL0_TILE]] :
154-
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
155155
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
156156
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
157157
// CHECK-SAME: outs(%[[INIT1_TILE]] :
@@ -444,6 +444,7 @@ module attributes {transform.with_named_sequence} {
444444
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
445445
// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
446446
// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
447+
// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
447448
// CHECK: %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
448449
// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
449450
// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
@@ -458,7 +459,6 @@ module attributes {transform.with_named_sequence} {
458459
// CHECK-DAG: %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]]
459460
// CHECK-DAG: %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] :
460461
// CHECK-SAME: outs(%[[SLICE_ARG4]] :
461-
// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
462462
// CHECK-DAG: %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
463463
// CHECK-DAG: %[[TILE_GEMM3:.+]] = linalg.matmul
464464
// CHECK-SAME: ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] :
@@ -688,3 +688,44 @@ module attributes {transform.with_named_sequence} {
688688
// CHECK: }
689689
// CHECK: }
690690

691+
// -----
692+
693+
func.func @pooling_ncw_max_fill_fuse(%input: tensor<?x?x?xf32>, %fake: tensor<?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
694+
%cst = arith.constant 0.000000e+00 : f32
695+
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
696+
%res = linalg.pooling_ncw_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
697+
ins(%input, %fake: tensor<?x?x?xf32>, tensor<?xf32>)
698+
outs(%fill: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
699+
return %res : tensor<?x?x?xf32>
700+
}
701+
702+
module attributes {transform.with_named_sequence} {
703+
transform.named_sequence @__transform_main(
704+
%arg0: !transform.any_op {transform.readonly}) {
705+
%0 = transform.structured.match ops{["linalg.pooling_ncw_max"]} in %arg0 : (!transform.any_op) -> !transform.any_op
706+
%tiled_pool, %loops0:4 = transform.structured.fuse %0 {tile_sizes = [1, 16, 1, 1], apply_cleanup = true}
707+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
708+
transform.yield
709+
}
710+
}
711+
712+
// CHECK-LABEL: func.func @pooling_ncw_max_fill_fuse(
713+
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x?x?xf32>,
714+
// CHECK-SAME: %[[FAKE:.*]]: tensor<?xf32>,
715+
// CHECK-SAME: %[[INIT:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
716+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
717+
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
718+
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
719+
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
720+
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
721+
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] =
722+
// CHECK-SAME: iter_args(%[[ITERARG2:.+]] = %[[ITERARG1]])
723+
// CHECK: %[[FILL_EXTRACT:.*]] = tensor.extract_slice %[[ITERARG2]]{{\[}}%[[IV0]], %[[IV1]], %[[IV2]]]
724+
// CHECK: %[[TILED_FILL:.*]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[FILL_EXTRACT]] : tensor<1x?x1xf32>) -> tensor<1x?x1xf32>
725+
// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] =
726+
// CHECK-SAME: iter_args(%[[ITERARG3:.*]] = %[[ITERARG2]], %[[ITERARG4:.*]] = %[[TILED_FILL]])
727+
// CHECK: %[[TILED_INPUT:.*]] = tensor.extract_slice %[[INPUT]]{{\[}}%[[IV0]], %[[IV1]]
728+
// CHECK: %[[TILED_FAKE:.*]] = tensor.extract_slice %[[FAKE]]{{\[}}%[[IV3]]]
729+
// CHECK: linalg.pooling_ncw_max
730+
// CHECK-SAME: ins(%[[TILED_INPUT]], %[[TILED_FAKE]] :
731+
// CHECK-SAME: outs(%[[ITERARG4]] :

mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
3737
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
3838
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
3939
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
40+
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
4041
// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] =
4142
// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
4243
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
@@ -47,7 +48,6 @@ module attributes {transform.with_named_sequence} {
4748
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
4849
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
4950
// CHECK-SAME: outs(%[[FILL0_TILE]] :
50-
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
5151
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
5252
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
5353
// CHECK-SAME: outs(%[[INIT1_TILE]] :

mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
3737
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
3838
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
3939
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
40+
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
4041
// CHECK: %[[RESULT:.+]]:2 = scf.forall (%[[IV:[a-zA-Z0-9]+]]) =
4142
// CHECK-SAME: shared_outs(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
4243
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
@@ -47,7 +48,6 @@ module attributes {transform.with_named_sequence} {
4748
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
4849
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
4950
// CHECK-SAME: outs(%[[FILL0_TILE]] :
50-
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
5151
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
5252
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
5353
// CHECK-SAME: outs(%[[INIT1_TILE]] :

0 commit comments

Comments
 (0)