@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as pack_as_pad, but since we explicitly added
+// {lowerPadLikeWithInsertSlice = false}, it should not be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_disallowed_as_pad(
+func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
+  // CHECK: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+  // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
+  // CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
+  // CHECK: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
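+  // A note on the shapes: each inner tile spans (and pads up to) the whole
+  // corresponding source dim (129 -> 136, 47 -> 64, 16 -> 16), so every
+  // outer dim of the result is 1 and the pack is pad-like.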
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
+      : (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+    transform.yield
+  }
+}
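+
+// For contrast, a sketch (not exercised by this test) of the default
+// invocation, with lowerPadLikeWithInsertSlice left at true, which would
+// rewrite this pad-like pack into a single tensor.insert_slice as in the
+// pack_as_pad test above:
+//
+//   transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+//     -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)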
+
+// -----
+
 // Check that we don't lower the following pack as a pad.
 // Although all the outermost dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,38 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as unpack_as_pad, but since we explicitly added
+// {lowerUnpadLikeWithExtractSlice = false}, it should not be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
+func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // tensor.unpack is lowered to tensor.extract_slice + linalg.transpose + tensor.collapse_shape
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+  // CHECK: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
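+  // A note on the shapes: the all-1 outer dims make this unpack unpad-like;
+  // it only strips the padding (136 -> 129, 64 -> 47) from the single tile.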
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}
+      : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
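+
+// For contrast, a sketch (not exercised by this test) of the default
+// invocation, with lowerUnpadLikeWithExtractSlice left at true, which would
+// rewrite this unpad-like unpack into a single tensor.extract_slice as in
+// the unpack_as_pad test above:
+//
+//   transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+//     -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">,
+//         !transform.op<"tensor.collapse_shape">, !transform.op<"tensor.extract_slice">)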
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)