@@ -19,13 +19,14 @@ func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
1919
2020// -----
2121
22- func.func @simple_pad_and_pack (%input: tensor <5 x1 xf32 >, %output: tensor <1 x1 x8 x2 xf32 >, %pad: f32 ) -> tensor <1 x1 x8 x2 xf32 > {
22+ func.func @simple_pad_and_pack_static_tiles (%input: tensor <5 x1 xf32 >, %output: tensor <1 x1 x8 x2 xf32 >, %pad: f32 ) -> tensor <1 x1 x8 x2 xf32 > {
2323 %0 = tensor.pack %input padding_value (%pad : f32 ) inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 2 ] into %output : tensor <5 x1 xf32 > -> tensor <1 x1 x8 x2 xf32 >
2424 return %0 : tensor <1 x1 x8 x2 xf32 >
2525}
2626// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 - 5)>
27+ // CHECK: #[[$ATTR_1:.+]] = affine_map<()[s0] -> (s0 - 1)>
2728
28- // CHECK-LABEL: func.func @simple_pad_and_pack
29+ // CHECK-LABEL: func.func @simple_pad_and_pack_static_tiles
2930// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
3031// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
3132// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
@@ -36,18 +37,18 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
3637// CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
3738// CHECK: return %[[INSERT]]
3839
39- /// Same as example above, but with dynamic tile size.
40+ /// Same as example above, but with one dynamic tile size.
4041
41- func.func @simple_pad_and_pack_dynamic (%input: tensor <5 x1 xf32 >, %output: tensor <1 x1 x?x2 xf32 >, %pad: f32 , %high: index ) -> tensor <1 x1 x?x2 xf32 > {
42+ func.func @simple_pad_and_pack_dynamic_tile (%input: tensor <5 x1 xf32 >, %output: tensor <1 x1 x?x2 xf32 >, %pad: f32 , %high: index ) -> tensor <1 x1 x?x2 xf32 > {
4243 %0 = tensor.pack %input padding_value (%pad : f32 ) inner_dims_pos = [0 , 1 ] inner_tiles = [%high , 2 ] into %output : tensor <5 x1 xf32 > -> tensor <1 x1 x?x2 xf32 >
4344 return %0 : tensor <1 x1 x?x2 xf32 >
4445}
4546
46- // CHECK-LABEL: func.func @simple_pad_and_pack_dynamic (
47+ // CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tile (
4748// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
4849// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
4950// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
50- // CHECK-SAME: %[[HIGH_VAL:.* ]]: index) -> tensor<1x1x?x2xf32> {
51+ // CHECK-SAME: %[[HIGH_VAL:[a-zA-Z0-9]+ ]]: index) -> tensor<1x1x?x2xf32> {
5152// CHECK: %[[C2:.*]] = arith.constant 2 : index
5253// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
5354// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
@@ -58,21 +59,21 @@ func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<
5859// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
5960// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
6061
61- /// Same as example above, but with scalable tile size.
62+ /// Same as example above, but with one scalable tile size.
6263
6364/// NOTE: For this example to make sense in practice, the "?" in the output shape
6465/// should effectively be 8 * vector.vscale (and that's what tensor.dim
6566/// below should return).
6667
67- func.func @simple_pad_and_pack_scalable (%input: tensor <5 x1 xf32 >, %output: tensor <1 x1 x?x2 xf32 >, %pad: f32 ) -> tensor <1 x1 x?x2 xf32 > {
68+ func.func @simple_pad_and_pack_scalable_tile (%input: tensor <5 x1 xf32 >, %output: tensor <1 x1 x?x2 xf32 >, %pad: f32 ) -> tensor <1 x1 x?x2 xf32 > {
6869 %c8 = arith.constant 8 : index
6970 %vscale = vector.vscale
7071 %c8_vscale = arith.muli %vscale , %c8 : index
7172 %0 = tensor.pack %input padding_value (%pad : f32 ) inner_dims_pos = [0 , 1 ] inner_tiles = [%c8_vscale , 2 ] into %output : tensor <5 x1 xf32 > -> tensor <1 x1 x?x2 xf32 >
7273 return %0 : tensor <1 x1 x?x2 xf32 >
7374}
7475
75- // CHECK-LABEL: func.func @simple_pad_and_pack_scalable (
76+ // CHECK-LABEL: func.func @simple_pad_and_pack_scalable_tile (
7677// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
7778// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
7879// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
@@ -89,6 +90,31 @@ func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor
8990// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
9091// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
9192
93+ /// Same as example above, but with both tile sizes dynamic (so both high pad sizes are computed via affine.apply).
94+
95+ func.func @simple_pad_and_pack_dynamic_tiles (%input: tensor <5 x1 xf32 >, %output: tensor <1 x1 x?x?xf32 >, %pad: f32 , %high_1: index , %high_2: index ) -> tensor <1 x1 x?x?xf32 > {
96+ %0 = tensor.pack %input padding_value (%pad : f32 ) inner_dims_pos = [0 , 1 ] inner_tiles = [%high_1 , %high_2 ] into %output : tensor <5 x1 xf32 > -> tensor <1 x1 x?x?xf32 >
97+ return %0 : tensor <1 x1 x?x?xf32 >
98+ }
99+ // CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tiles(
100+ // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
101+ // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x?xf32>,
102+ // CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32,
103+ // CHECK-SAME: %[[HIGH_VAL_1:[a-zA-Z0-9]+]]: index,
104+ // CHECK-SAME: %[[HIGH_VAL_2:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x?xf32> {
105+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
106+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
107+ // CHECK: %[[PAD_HIGH_1:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL_1]]]
108+ // CHECK: %[[PAD_HIGH_2:.*]] = affine.apply #[[$ATTR_1]](){{\[}}%[[HIGH_VAL_2]]]
109+ // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH_1]], %[[PAD_HIGH_2]]] {
110+ // CHECK: tensor.yield %[[PAD_VAL]] : f32
111+ // CHECK-NOT: linalg.transpose
112+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD]][0, 0] {{\[}}%[[HIGH_VAL_1]], %[[HIGH_VAL_2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
113+ // CHECK: %[[DIM_1:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x?xf32>
114+ // CHECK: %[[DIM_2:.*]] = tensor.dim %[[DEST]], %[[C3]] : tensor<1x1x?x?xf32>
115+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM_1]], %[[DIM_2]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
116+ // CHECK: return %[[RES]] : tensor<1x1x?x?xf32>
117+
92118// -----
93119
94120func.func @simple_NC_to_CNnc (%arg0: tensor <32 x8 xf32 >, %arg1: tensor <1 x1 x32 x8 xf32 >) -> tensor <1 x1 x32 x8 xf32 >{
0 commit comments