@@ -451,6 +451,56 @@ module attributes {transform.with_named_sequence} {
451
451
452
452
// -----

// Fuses a linalg.pack consumer into an scf.forall that runs a single
// iteration: the loop goes (0) to (4) with step (16), so the one tile covers
// the whole 4x4 input (affine.min picks min(-d0 + 4, 16) = 4 at d0 = 0).
// The pack uses a padding value because 4 is not a multiple of the inner
// tile size 16.
#map = affine_map<(d0) -> (-d0 + 4, 16)>
func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> {
  %0 = tensor.empty() : tensor<1x4x16x1xf32>
  %1 = tensor.empty() : tensor<4x4xf32>
  %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
    %3 = affine.min #map(%arg1)
    %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
    %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
    %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
    }
  }
  %cst = arith.constant 0.000000e+00 : f32
  %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32>
  return %pack : tensor<1x4x16x1xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)>
//      CHECK: func.func @fuse_pack_consumer_if_single_iteration(
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
//  CHECK-DAG:   %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32>
//  CHECK-DAG:   %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
// CHECK-SAME:     shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
//  CHECK-DAG:     %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
//  CHECK-DAG:     %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
//  CHECK-DAG:     %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
//      CHECK:     %[[ELEM:.*]] = linalg.exp
// CHECK-SAME:       ins(%[[ELEM_SRC]]
// CHECK-SAME:       outs(%[[ELEM_DEST]]
//  CHECK-DAG:     %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
//      CHECK:     %[[PACK:.*]] = linalg.pack %[[ELEM]]
// CHECK-SAME:       padding_value(%[[PAD_VAL]] : f32)
// CHECK-SAME:       outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
// CHECK-SAME:       into %[[TILED_PACK_DEST]]
//      CHECK:     scf.forall.in_parallel {
//      CHECK:       tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
//      CHECK:       tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]

// -----
454
504
455
505
func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm (%arg0: tensor <64 x32 xf32 >, %arg1: tensor <64 x32 xf32 >, %arg2: tensor <2 x64 x16 x1 xf32 >) -> tensor <2 x64 x16 x1 xf32 > {
456
506
%0 = scf.forall (%arg3 ) = (0 ) to (32 ) step (16 ) shared_outs (%arg4 = %arg1 ) -> (tensor <64 x32 xf32 >) {
0 commit comments