@@ -451,6 +451,51 @@ module attributes {transform.with_named_sequence} {
451451
452452// ----- 
453453
454+ 
// Producer: an elementwise linalg.exp computed inside an scf.forall that walks
// dim 1 of the 64x32 operand in steps of 16, one 64x16 slice per iteration.
// Consumer: a linalg.pack of the forall result with outer_dims_perm = [1, 0];
// its inner tile on source dim 1 is 16, the same as the forall step, so each
// iteration's slice corresponds to whole packed tiles (no padding value used).
func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
  %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
    %src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
    %dest = tensor.extract_slice %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
    %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %1 into %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
    }
  }
  // Packed layout: outer dims are (32 / 16, 64 / 1) after the [1, 0]
  // permutation, followed by the inner tiles (16, 1).
  %pack = linalg.pack %0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] into %arg2 : tensor<64x32xf32> -> tensor<2x64x16x1xf32>
  return %pack : tensor<2x64x16x1xf32>
}
467+ 
// Transform script for the function above: locate the
// tensor.parallel_insert_slice that yields the producer tile and the scf.forall
// loop, then ask the test transform op to fuse the consumer of that slice into
// the loop.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Op names must match exactly; any stray whitespace inside the quotes
    // makes the match return an empty handle.
    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in (%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
//      CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
//      CHECK:      %[[ELEM:.*]] = linalg.exp
// CHECK-SAME:        ins(%[[ELEM_SRC]]
// CHECK-SAME:        outs(%[[ELEM_DEST]]
//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
// CHECK-SAME:        outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
// CHECK-SAME:        into %[[TILED_PACK_DEST]]
//      CHECK:      scf.forall.in_parallel {
//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
496+ 
497+ // ----- 
498+ 
454499// It is valid to fuse the pack op in perfect tiling scenario when the dimension 
455500// is dynamic and padding is not needed. 
456501
0 commit comments