@@ -455,13 +455,10 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
455455// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
456456// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
457457// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
458- // CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
459- // CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
460- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
461- // CHECK-SAME: into %[[ARG0_EMPTY_PACK]]
458+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
462459// CHECK: %[[RES:.+]] = linalg.generic
463460// CHECK-SAME: indexing_maps = [#[[$MAP]]]
464- // CHECK-SAME: outs(%[[PACKED_ARG0]]
461+ // CHECK-SAME: outs(%[[EMPTY]]
465462// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
466463// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
467464// CHECK-SAME: into %[[UNPACKED_ARG0]]
@@ -485,22 +482,11 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
485482// CHECK-LABEL: func.func @unpack_on_input
486483// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
487484// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
488- // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
489- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
490- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
491- // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
492- // CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
493- // CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
494- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
495- // CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
496- // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
497- // CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
498- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
499- // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
485+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
500486// CHECK: %[[RES:.+]] = linalg.generic
501487// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
502- // CHECK-SAME: ins(%[[ARG0_PACK]]
503- // CHECK-SAME: outs(%[[ARG1_PACK]]
488+ // CHECK-SAME: ins(%[[ARG0]]
489+ // CHECK-SAME: outs(%[[EMPTY]]
504490// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
505491// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
506492// CHECK-SAME: into %[[ARG1]]
@@ -524,22 +510,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
524510// CHECK-LABEL: func.func @unpack_element_type_change
525511// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
526512// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
527- // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
528- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
529- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
530- // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
531- // CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
532- // CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
533- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
534- // CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
535- // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
536- // CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
537- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
538- // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
513+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
539514// CHECK: %[[RES:.+]] = linalg.generic
540515// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
541- // CHECK-SAME: ins(%[[ARG0_PACK]]
542- // CHECK-SAME: outs(%[[ARG1_PACK]]
516+ // CHECK-SAME: ins(%[[ARG0]]
517+ // CHECK-SAME: outs(%[[EMPTY]]
543518// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
544519// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
545520// CHECK-SAME: into %[[ARG1]]
@@ -564,19 +539,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
564539// CHECK-LABEL: func.func @forward_tensor_empty
565540// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
566541// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
567- // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
568- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
569- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
570- // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
571- // CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
572- // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
573- // CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
574- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
575- // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
542+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
576543// CHECK: %[[RES:.+]] = linalg.generic
577544// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
578- // CHECK-SAME: ins(%[[PACKED_ARG0]]
579- // CHECK-SAME: outs(%[[DEST]]
545+ // CHECK-SAME: ins(%[[ARG0]]
546+ // CHECK-SAME: outs(%[[EMPTY]]
580547// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
581548// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
582549// CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +777,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
810777}
811778
812779// CHECK-LABEL: func.func @unpack_empty_inner_dims
813- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
814- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
815- // CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
816- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
780+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
817781// CHECK: %[[RES:.+]] = linalg.generic
818- // CHECK-SAME: ins(%[[PACKED_ARG0]]
782+ // CHECK-SAME: ins(%[[ARG0]]
819783// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
820784// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
821785
@@ -943,14 +907,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
943907// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
944908// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
945909// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
946- // CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
947- // CHECK: %[[PACK_ARG0:.+]] = linalg.pack
948- // CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
949- // CHECK-SAME: into %[[PACK_EMPTY]]
950910// CHECK: %[[POOL:.+]] = linalg.generic
951911// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
952912// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
953- // CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
913+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
954914// CHECK-SAME: outs(%[[INIT]]
955915// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
956916// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1381,48 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
14211381// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
14221382// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
14231383// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
1384+
1385+ // -----
1386+
1387+ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %dest: tensor<?x64xf32>, %arg1: tensor<?x64xbf16>) -> tensor<?x64xbf16> {  // Input: an unpack feeding an elementwise f32 -> bf16 truncation.
1388+   %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>  // Packed source is static; the unpacked result has a dynamic leading dim.
1389+   %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xbf16>) {
1390+   ^bb0(%in: f32, %out: bf16):  // %out is never read in this region.
1391+     %1 = arith.truncf %in : f32 to bf16
1392+     linalg.yield %1 : bf16
1393+   } -> tensor<?x64xbf16>
1394+   return %0 : tensor<?x64xbf16>
1395+ }
1396+ // CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable
1397+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1398+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1399+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1400+ // CHECK: %[[EMPTY:.+]] = tensor.empty
1401+ // CHECK: %[[GENERIC:.+]] = linalg.generic
1402+ // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
1403+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
1404+ // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
1405+ // CHECK-SAME: into %[[ARG2]]
1406+ // CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
1407+
1408+ // -----
1409+
1410+ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> {  // Input: an unpack feeding an elementwise add whose init operand is also the unpack destination.
1411+   %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 : tensor<8x8x4x8xf32> -> tensor<?x64xf32>  // %arg1 is the unpack destination here ...
1412+   %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xf32>) {  // ... and the generic's init operand here.
1413+   ^bb0(%in: f32, %out: f32):  // Unlike a pure truncation, %out is read by the addf below.
1414+     %1 = arith.addf %in, %out : f32
1415+     linalg.yield %1 : f32
1416+   } -> tensor<?x64xf32>
1417+   return %0 : tensor<?x64xf32>
1418+ }
1419+ // CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used
1420+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1421+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1422+ // CHECK: %[[EMPTY:.+]] = tensor.empty
1423+ // CHECK: %[[GENERIC:.+]] = linalg.generic
1424+ // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
1425+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xf32>)
1426+ // CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
1427+ // CHECK-SAME: into %[[ARG1]]
1428+ // CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
0 commit comments