@@ -395,7 +395,7 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-  func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+  func.func @fuse_perfect_tiling_pack_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
@@ -429,7 +429,7 @@ module attributes {transform.with_named_sequence} {
   }
 }
 // CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
-// CHECK: func.func @fuse_pack_consumer_into_scf_forall(
+// CHECK: func.func @fuse_perfect_tiling_pack_consumer(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
@@ -451,7 +451,10 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> {
+// It is valid to fuse the pack op with padding semantics if the dimension does
+// not need padding.
+
+func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> {
   %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
     %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
     %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -475,7 +478,7 @@ module attributes {transform.with_named_sequence} {
   }
 }
 // CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
-// CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(
+// CHECK: func.func @fuse_pack_consumer_with_padding_semantics(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x2x3x16xf32>
@@ -488,18 +491,72 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: ins(%[[ELEM_SRC]]
 // CHECK-SAME: outs(%[[ELEM_DEST]]
 // CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1]
 // CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
 // CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
 // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
 // CHECK-SAME: into %[[TILED_PACK_DEST]]
 // CHECK: scf.forall.in_parallel {
 // CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1]
 
 // -----
 
-func.func @nofuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+// It is valid to fuse the pack if the dimension is not tiled even when it needs
+// extra padding.
+
+func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
+  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %1 = tensor.empty() : tensor<33x2x3x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
+  return %pack : tensor<33x2x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+
+// -----
+
+// If the dimension is tiled and it needs extra padding, do not fuse the pack
+// op.
+
+func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
   %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
     %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
     %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -526,6 +583,42 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Imperfect tiling is not supported in pack op consumer fusion.
+
+#map = affine_map<(d0) -> (d0 * 5)>
+#map1 = affine_map<(d0) -> (d0)>
+func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x6xf32> {
+  %0 = tensor.empty() : tensor<30xf32>
+  %1 = scf.forall (%arg1) in (6) shared_outs(%arg2 = %0) -> (tensor<30xf32>) {
+    %3 = affine.apply #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
+    %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<5xf32>) outs(%extracted_slice_0 : tensor<5xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %5 = arith.addf %in, %in : f32
+      linalg.yield %5 : f32
+    } -> tensor<5xf32>
+    scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
+      tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32>
+    }
+  }
+  %2 = tensor.empty() : tensor<5x6xf32>
+  %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32>
+  return %pack : tensor<5x6xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 module {
   func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
     %c0 = arith.constant 0 : index