@@ -28,9 +28,9 @@ module attributes {transform.with_named_sequence} {
2828// CHECK: %[[INIT:.+]] = tensor.empty
2929// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
3030// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
31+ // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
3132// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
3233// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
33- // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
3434// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
3535// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
3636// CHECK: %[[FILL_TILE:.+]] = linalg.fill
@@ -141,6 +141,7 @@ module attributes {transform.with_named_sequence} {
141141// CHECK-DAG: %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]])
142142// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
143143// CHECK: %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D2]])
144+ // CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
144145// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
145146// CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]])
146147// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
@@ -151,7 +152,6 @@ module attributes {transform.with_named_sequence} {
151152// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
152153// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
153154// CHECK-SAME: outs(%[[FILL0_TILE]] :
154- // CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
155155// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
156156// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
157157// CHECK-SAME: outs(%[[INIT1_TILE]] :
@@ -444,6 +444,7 @@ module attributes {transform.with_named_sequence} {
444444// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
445445// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
446446// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
447+ // CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
447448// CHECK: %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
448449// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
449450// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
@@ -458,7 +459,6 @@ module attributes {transform.with_named_sequence} {
458459// CHECK-DAG: %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]]
459460// CHECK-DAG: %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] :
460461// CHECK-SAME: outs(%[[SLICE_ARG4]] :
461- // CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
462462// CHECK-DAG: %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
463463// CHECK-DAG: %[[TILE_GEMM3:.+]] = linalg.matmul
464464// CHECK-SAME: ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] :
@@ -688,3 +688,44 @@ module attributes {transform.with_named_sequence} {
688688// CHECK: }
689689// CHECK: }
690690
691+ // -----
692+
693+ func.func @pooling_ncw_max_fill_fuse (%input: tensor <?x?x?xf32 >, %fake: tensor <?xf32 >, %init: tensor <?x?x?xf32 >) -> tensor <?x?x?xf32 > {
694+ %cst = arith.constant 0.000000e+00 : f32
695+ %fill = linalg.fill ins (%cst : f32 ) outs (%init : tensor <?x?x?xf32 >) -> tensor <?x?x?xf32 >
696+ %res = linalg.pooling_ncw_max {dilations = dense <1 > : tensor <1 xi64 >, strides = dense <1 > : tensor <1 xi64 >}
697+ ins (%input , %fake: tensor <?x?x?xf32 >, tensor <?xf32 >)
698+ outs (%fill: tensor <?x?x?xf32 >) -> tensor <?x?x?xf32 >
699+ return %res : tensor <?x?x?xf32 >
700+ }
701+
702+ module attributes {transform.with_named_sequence } {
703+ transform.named_sequence @__transform_main (
704+ %arg0: !transform.any_op {transform.readonly }) {
705+ %0 = transform.structured.match ops {[" linalg.pooling_ncw_max" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
706+ %tiled_pool , %loops0:4 = transform.structured.fuse %0 {tile_sizes = [1 , 16 , 1 , 1 ], apply_cleanup = true }
707+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op , !transform.any_op , !transform.any_op , !transform.any_op )
708+ transform.yield
709+ }
710+ }
711+
712+ // CHECK-LABEL: func.func @pooling_ncw_max_fill_fuse(
713+ // CHECK-SAME: %[[INPUT:.*]]: tensor<?x?x?xf32>,
714+ // CHECK-SAME: %[[FAKE:.*]]: tensor<?xf32>,
715+ // CHECK-SAME: %[[INIT:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
716+ // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
717+ // CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
718+ // CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
719+ // CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
720+ // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
721+ // CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] =
722+ // CHECK-SAME: iter_args(%[[ITERARG2:.+]] = %[[ITERARG1]])
723+ // CHECK: %[[FILL_EXTRACT:.*]] = tensor.extract_slice %[[ITERARG2]]{{\[}}%[[IV0]], %[[IV1]], %[[IV2]]]
724+ // CHECK: %[[TILED_FILL:.*]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[FILL_EXTRACT]] : tensor<1x?x1xf32>) -> tensor<1x?x1xf32>
725+ // CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] =
726+ // CHECK-SAME: iter_args(%[[ITERARG3:.*]] = %[[ITERARG2]], %[[ITERARG4:.*]] = %[[TILED_FILL]])
727+ // CHECK: %[[TILED_INPUT:.*]] = tensor.extract_slice %[[INPUT]]{{\[}}%[[IV0]], %[[IV1]]
728+ // CHECK: %[[TILED_FAKE:.*]] = tensor.extract_slice %[[FAKE]]{{\[}}%[[IV3]]]
729+ // CHECK: linalg.pooling_ncw_max
730+ // CHECK-SAME: ins(%[[TILED_INPUT]], %[[TILED_FAKE]] :
731+ // CHECK-SAME: outs(%[[ITERARG4]] :
0 commit comments