@@ -451,6 +451,51 @@ module attributes {transform.with_named_sequence} {
451451
452452// -----
453453
454+
// Payload IR for the consumer-fusion test with a permuted pack.
//
// The scf.forall walks the second dimension of the 64x32 tensors from 0 to 32
// in steps of 16. Each iteration extracts a 64x16 slice of %arg0 (input) and
// of the shared output %arg4, applies linalg.exp to it, and writes the result
// back into %arg4 at column offset %arg3 via tensor.parallel_insert_slice.
//
// The linalg.pack that follows consumes the forall result. It tiles dimension
// 1 by 16 and dimension 0 by 1 (inner_dims_pos = [1, 0], inner_tiles =
// [16, 1]) and swaps the two outer dimensions (outer_dims_perm = [1, 0]),
// producing tensor<2x64x16x1xf32>. The forall step (16) equals the inner tile
// size used on dimension 1, so no padding value is specified on the pack.
//
// The FileCheck expectations further down require the pack to end up inside
// the loop body, operating on a [1, 64, 16, 1] slice of a second shared
// output initialised from %arg2.
455+ func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm (%arg0: tensor <64 x32 xf32 >, %arg1: tensor <64 x32 xf32 >, %arg2: tensor <2 x64 x16 x1 xf32 >) -> tensor <2 x64 x16 x1 xf32 > {
456+ %0 = scf.forall (%arg3 ) = (0 ) to (32 ) step (16 ) shared_outs (%arg4 = %arg1 ) -> (tensor <64 x32 xf32 >) {
457+ %src = tensor.extract_slice %arg0 [0 , %arg3 ] [64 , 16 ] [1 , 1 ] : tensor <64 x32 xf32 > to tensor <64 x16 xf32 >
458+ %dest = tensor.extract_slice %arg4 [0 , %arg3 ] [64 , 16 ] [1 , 1 ] : tensor <64 x32 xf32 > to tensor <64 x16 xf32 >
459+ %1 = linalg.exp ins (%src : tensor <64 x16 xf32 >) outs (%dest : tensor <64 x16 xf32 >) -> tensor <64 x16 xf32 >
460+ scf.forall.in_parallel {
461+ tensor.parallel_insert_slice %1 into %arg4 [0 , %arg3 ] [64 , 16 ] [1 , 1 ] : tensor <64 x16 xf32 > into tensor <64 x32 xf32 >
462+ }
463+ }
464+ %pack = linalg.pack %0 outer_dims_perm = [1 , 0 ] inner_dims_pos = [1 , 0 ] inner_tiles = [16 , 1 ] into %arg2 : tensor <64 x32 xf32 > -> tensor <2 x64 x16 x1 xf32 >
465+ return %pack : tensor <2 x64 x16 x1 xf32 >
466+ }
467+
// Transform script that drives the fusion for the payload function above.
//
// It collects two handles from the payload root %arg0: one for the ops matched
// by the tensor.parallel_insert_slice name and one for the ops matched by the
// scf.forall name. Both handles are passed to transform.test.fuse_consumer,
// which returns two handles (%consumer, %fused_consumer); neither is used
// afterwards, and the sequence yields nothing.
//
// NOTE(review): the op-name strings passed to transform.structured.match
// appear here with a leading space inside the quotes. That looks like an
// artifact of how this text was captured; confirm the checked-in test uses
// the exact op names, since the match is by name.
468+ module attributes {transform.with_named_sequence } {
469+ transform.named_sequence @__transform_main (%arg0: !transform.any_op {transform.readonly }) {
470+ %0 = transform.structured.match ops {[" tensor.parallel_insert_slice" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
471+ %1 = transform.structured.match ops {[" scf.forall" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
472+ %consumer , %fused_consumer = transform.test.fuse_consumer %0 in (%1 ) : (!transform.any_op , !transform.any_op ) -> (!transform.any_op , !transform.any_op )
473+ transform.yield
474+ }
475+ }
476+ // CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
477+ // CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(
478+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
479+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
480+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
481+ // CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
482+ // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
483+ // CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
484+ // CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
485+ // CHECK: %[[ELEM:.*]] = linalg.exp
486+ // CHECK-SAME: ins(%[[ELEM_SRC]]
487+ // CHECK-SAME: outs(%[[ELEM_DEST]]
488+ // CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
489+ // CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
490+ // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
491+ // CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
492+ // CHECK-SAME: into %[[TILED_PACK_DEST]]
493+ // CHECK: scf.forall.in_parallel {
494+ // CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
495+ // CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
496+
497+ // -----
498+
454499// It is valid to fuse the pack op in a perfect tiling scenario when the
455500// dimension is dynamic and padding is not needed.
456501
0 commit comments