@@ -451,6 +451,55 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// It is valid to fuse the pack op in a perfect tiling scenario when the
+// dimension is dynamic and padding is not needed.
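+// Perfect tiling here means the forall step, the slice width, and the pack
+// op's inner_tiles are all 16, so each 64x16 tile fills exactly one inner
+// tile of the packed 64x?x16 destination and no padding value is required.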
+
+func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, %arg1: tensor<64x?xf32>, %1: tensor<64x?x16xf32>) -> tensor<64x?x16xf32> {
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<64x?xf32>
+  %0 = scf.forall (%arg2) = (0) to (%d1) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x?xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32>
+    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x?xf32>
+    }
+  }
+  %pack = linalg.pack %0 inner_dims_pos = [1] inner_tiles = [16] into %1 : tensor<64x?xf32> -> tensor<64x?x16xf32>
+  return %pack : tensor<64x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:        %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (%{{.+}}) step (16)
+// CHECK-SAME:     shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+// CHECK:          %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK:          %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK:          %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:       ins(%[[ELEM_SRC]]
+// CHECK-SAME:       outs(%[[ELEM_DEST]]
+// CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1]
+// CHECK:          %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:       inner_dims_pos = [1] inner_tiles = [16]
+// CHECK-SAME:       into %[[TILED_PACK_DEST]]
+// CHECK:          scf.forall.in_parallel {
+// CHECK:            tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK:            tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1]
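+// Note: #[[PACK_RESULT_MAP]] converts the untiled loop offset %[[IV]] into
+// the packed tensor's outer-dimension offset, e.g. %[[IV]] = 32 lands at
+// outer index 32 floordiv 16 = 2, covering a single [64, 1, 16] tile.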
+
+// -----
+
 // It is valid to fuse the pack op with padding semantics if the dimension does
 // not need padding.
 