@@ -1772,3 +1772,78 @@ func.func @fold_cast_unpack_dynamic_tile_size(
17721772 into %res {test_attr } : tensor <1 x1 x?x1 xi32 > -> tensor <7 x?xi32 >
17731773 return %unpack : tensor <7 x?xi32 >
17741774}
1775+
1776+ // -----
1777+
1778+ //===----------------------------------------------------------------------===//
1779+ // linalg.unpack + tensor.extract_slice
1780+ //===----------------------------------------------------------------------===//
1781+
// Positive case: a tensor.extract_slice consumes the result of linalg.unpack.
// The slice starts at offset 0 in every dimension, uses unit strides, and keeps
// the rank (3-D in, 3-D out). The CHECK lines that follow this function expect
// the extract_slice to be applied to %dest instead, with linalg.unpack writing
// directly into that smaller destination.
// NOTE(review): the RUN line is outside this chunk, so which pass performs the
// rewrite is not visible here; presumably canonicalization, to be confirmed.
1782+ func.func @fold_extract_slice_into_unpack (
1783+ %src : tensor <28 x2 x?x16 x16 xf32 >, %dest : tensor <28 x32 x?xf32 >, %size : index
1784+ ) -> tensor <28 x28 x?xf32 > {
// Unpack %src into %dest; dimensions 1 and 2 each carry an inner tile of 16
// (inner_dims_pos = [1, 2], inner_tiles = [16, 16]).
1785+ %unpack = linalg.unpack %src
1786+ outer_dims_perm = [0 , 1 , 2 ]
1787+ inner_dims_pos = [1 , 2 ]
1788+ inner_tiles = [16 , 16 ]
1789+ into %dest : tensor <28 x2 x?x16 x16 xf32 > -> tensor <28 x32 x?xf32 >
// Trim dimension 1 from 32 to 28 and take %size elements of the dynamic
// dimension; offsets are all 0 and strides are all 1.
1790+ %extracted_slice = tensor.extract_slice %unpack
1791+ [0 , 0 , 0 ] [28 , 28 , %size ] [1 , 1 , 1 ] : tensor <28 x32 x?xf32 > to tensor <28 x28 x?xf32 >
1792+ return %extracted_slice : tensor <28 x28 x?xf32 >
1793+ }
1794+
1795+ // CHECK-LABEL: func @fold_extract_slice_into_unpack
1796+ // CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
1797+ // CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
1798+ // CHECK-SAME: %[[SIZE:.+]]: index
1799+ // CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
1800+ // CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
1801+ // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
1802+ // CHECK-SAME: into %[[DEST_SLICE]]
1803+ // CHECK: return %[[UNPACK]]
1804+
1805+ // -----
1806+
// Negative case: the tensor.extract_slice on the unpack result is
// rank-reducing (2-D tensor<28x32xf32> to 1-D tensor<28xf32>). The CHECK lines
// that follow this function expect linalg.unpack to keep writing into %dest and
// the tensor.extract_slice on its result to remain in the output.
1807+ func.func @no_fold_extract_slice_into_unpack_rank_reducing (
1808+ %src : tensor <28 x2 x16 xf32 >, %dest : tensor <28 x32 xf32 >
1809+ ) -> tensor <28 xf32 > {
// Unpack %src into %dest; only dimension 1 carries an inner tile, of size 16.
1810+ %unpack = linalg.unpack %src
1811+ outer_dims_perm = [0 , 1 ]
1812+ inner_dims_pos = [1 ]
1813+ inner_tiles = [16 ]
1814+ into %dest : tensor <28 x2 x16 xf32 > -> tensor <28 x32 xf32 >
// Sizes [1, 28] with a 1-D result type: the unit dimension 0 is dropped, which
// is what makes this slice rank-reducing. Offsets are 0 and strides are 1.
1815+ %extracted_slice = tensor.extract_slice %unpack
1816+ [0 , 0 ] [1 , 28 ] [1 , 1 ] : tensor <28 x32 xf32 > to tensor <28 xf32 >
1817+ return %extracted_slice : tensor <28 xf32 >
1818+ }
1819+
1820+ // CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing
1821+ // CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32>
1822+ // CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32>
1823+ // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
1824+ // CHECK-SAME: into %[[DEST]]
1825+ // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
1826+ // CHECK: return %[[SLICE]]
1827+
1828+ // -----
1829+
// Negative case: the tensor.extract_slice on the unpack result has a non-zero
// offset (offsets [0, 1]). The CHECK lines that follow this function expect
// linalg.unpack to keep writing into %dest and the tensor.extract_slice on its
// result to remain in the output.
1830+ func.func @no_fold_extract_slice_into_unpack_non_zero_offset (
1831+ %src : tensor <28 x2 x16 xf32 >, %dest : tensor <28 x32 xf32 >
1832+ ) -> tensor <28 x28 xf32 > {
// Unpack %src into %dest; only dimension 1 carries an inner tile, of size 16.
1833+ %unpack = linalg.unpack %src
1834+ outer_dims_perm = [0 , 1 ]
1835+ inner_dims_pos = [1 ]
1836+ inner_tiles = [16 ]
1837+ into %dest : tensor <28 x2 x16 xf32 > -> tensor <28 x32 xf32 >
// Rank-preserving 28x28 slice with unit strides, but starting at offset 1 in
// dimension 1; that non-zero offset is the property this test exercises.
1838+ %extracted_slice = tensor.extract_slice %unpack
1839+ [0 , 1 ] [28 , 28 ] [1 , 1 ] : tensor <28 x32 xf32 > to tensor <28 x28 xf32 >
1840+ return %extracted_slice : tensor <28 x28 xf32 >
1841+ }
1842+
1843+ // CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
1844+ // CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32>
1845+ // CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32>
1846+ // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
1847+ // CHECK-SAME: into %[[DEST]]
1848+ // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
1849+ // CHECK: return %[[SLICE]]
0 commit comments