@@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
524524// CHECK-LABEL: func.func @unpack_element_type_change
525525// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
526526// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
527- // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
528- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
529- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
530- // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
531- // CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
532- // CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
533- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
534- // CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
535- // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
536- // CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
537- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
538- // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
527+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
539528// CHECK: %[[RES:.+]] = linalg.generic
540529// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
541- // CHECK-SAME: ins(%[[ARG0_PACK]]
542- // CHECK-SAME: outs(%[[ARG1_PACK]]
530+ // CHECK-SAME: ins(%[[ARG0]]
531+ // CHECK-SAME: outs(%[[EMPTY]]
543532// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
544533// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
545534// CHECK-SAME: into %[[ARG1]]
@@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
564553// CHECK-LABEL: func.func @forward_tensor_empty
565554// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
566555// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
567- // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
568- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
569- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
570- // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
571- // CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
572- // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
573- // CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
574- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
575- // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
556+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
576557// CHECK: %[[RES:.+]] = linalg.generic
577558// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
578- // CHECK-SAME: ins(%[[PACKED_ARG0]]
579- // CHECK-SAME: outs(%[[DEST]]
559+ // CHECK-SAME: ins(%[[ARG0]]
560+ // CHECK-SAME: outs(%[[EMPTY]]
580561// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
581562// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
582563// CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
810791}
811792
812793// CHECK-LABEL: func.func @unpack_empty_inner_dims
813- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
814- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
815- // CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
816- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
794+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
817795// CHECK: %[[RES:.+]] = linalg.generic
818- // CHECK-SAME: ins(%[[PACKED_ARG0]]
796+ // CHECK-SAME: ins(%[[ARG0]]
819797// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
820798// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
821799
@@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
943921// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
944922// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
945923// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
946- // CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
947- // CHECK: %[[PACK_ARG0:.+]] = linalg.pack
948- // CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
949- // CHECK-SAME: into %[[PACK_EMPTY]]
950924// CHECK: %[[POOL:.+]] = linalg.generic
951925// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
952926// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
953- // CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
927+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
954928// CHECK-SAME: outs(%[[INIT]]
955929// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
956930// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1395,27 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
14211395// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
14221396// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
14231397// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
1398+
1399+ // -----
1400+ // Unpack -> elementwise generic -> pack, with identical inner_dims_pos/inner_tiles on both ends; the CHECK lines below expect the generic to consume %arg0 directly in packed form.
1401+ #map = affine_map<(d0, d1) -> (d0, d1)>
1402+ func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
1403+   %empty = tensor.empty() : tensor<32x64xf32>
1404+   %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
1405+   %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
1406+   ^bb0(%in: f32, %out: f32):
1407+     %2 = arith.addf %in, %in : f32
1408+     linalg.yield %2 : f32
1409+   } -> tensor<32x64xf32>
1410+   %empty1 = tensor.empty() : tensor<8x8x4x8xf32>
1411+   %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
1412+   return %pack : tensor<8x8x4x8xf32>
1413+ }
1414+
1415+ // CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up
1416+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1417+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
1418+ // CHECK: %[[GENERIC:.+]] = linalg.generic
1419+ // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
1420+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
1421+ // CHECK: return %[[GENERIC]] : tensor<8x8x4x8xf32>
0 commit comments