@@ -451,6 +451,56 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
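+// Checks that the linalg.pack consumer is fused into the scf.forall even
+// though the loop performs a single iteration (trip count 4 with step 16),
+// and that the pack's padding_value is preserved on the fused op.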
+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> {
+  %0 = tensor.empty() : tensor<1x4x16x1xf32>
+  %1 = tensor.empty() : tensor<4x4xf32>
+  %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+    %3 = affine.min #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+    }
+  }
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32>
+  return %pack : tensor<1x4x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)>
+// CHECK: func.func @fuse_pack_consumer_if_single_iteration(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32>
+// CHECK-DAG: %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+// CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+
+// -----
 
 func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
   %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {