@@ -646,6 +646,87 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// It is valid to fuse the pack if the dimension is not tiled even when it needs
+// extra padding.
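+// Here the untiled dimension is dim 0: the scf.forall loop below only tiles
+// dim 1, while the pack pads dim 0 from 64 up to 33 outer tiles of size 3
+// (33 * 3 = 99), i.e. more than the minimal 22 tiles.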
+
+func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
+  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %1 = tensor.empty() : tensor<33x2x3x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
+  return %pack : tensor<33x2x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+
+// -----
+
+// If the dimension is tiled and it needs extra padding, do not fuse the pack
+// op.
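+// Here the tiled dimension is dim 1: the scf.forall loop below tiles it with
+// step 16, and the packed type reserves 32 outer tiles of size 16 for it
+// (32 * 16 = 512 for a dimension of size 32), so consumer fusion is rejected.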
+
+func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %1 = tensor.empty() : tensor<23x32x3x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
+  return %pack : tensor<23x32x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Imperfect tiling is not supported in pack op consumer fusion.
 
 #map = affine_map<(d0) -> (d0 * 5)>