@@ -114,6 +114,74 @@ module attributes {transform.with_named_sequence} {
114114 }
115115}
116116
117+ // -----
118+
119+ func.func private @make_vector () -> vector <7 x9 xf32 >
120+
121+ // Negative test - low pad is non-zero
122+
123+ // CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_low_pad
124+ // CHECK: tensor.pad
125+ func.func @pad_and_transfer_write_static_non_zero_low_pad (
126+ %arg0: tensor <5 x6 xf32 >) -> tensor <5 x6 xf32 > {
127+ %c0 = arith.constant 0 : index
128+ %c5 = arith.constant 5.0 : f32
129+ %0 = tensor.pad %arg0 low [0 , 1 ] high [5 , 6 ] {
130+ ^bb0 (%arg2: index , %arg3: index ):
131+ tensor.yield %c5 : f32
132+ } : tensor <5 x6 xf32 > to tensor <10 x13 xf32 >
133+ %1 = call @make_vector () : () -> vector <7 x9 xf32 >
134+ %2 = vector.transfer_write %1 , %0 [%c0 , %c0 ]
135+ : vector <7 x9 xf32 >, tensor <10 x13 xf32 >
136+ %3 = tensor.extract_slice %2 [0 , 0 ] [5 , 6 ] [1 , 1 ] : tensor <10 x13 xf32 > to tensor <5 x6 xf32 >
137+ return %3 : tensor <5 x6 xf32 >
138+ }
139+
140+ module attributes {transform.with_named_sequence } {
141+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
142+ %func_op = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.op <" func.func" >
143+
144+ transform.apply_patterns to %func_op {
145+ transform.apply_patterns.linalg.pad_vectorization
146+ } : !transform.op <" func.func" >
147+ transform.yield
148+ }
149+ }
150+
151+ // -----
152+
153+ // Negative test - TransferWriteOp result is not _directly_ consumed by an
154+ // ExtractSliceOp (noet the non-zero offset).
155+
156+ func.func private @make_vector () -> vector <7 x9 xf32 >
157+
158+ // CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_offset
159+ // CHECK: tensor.pad
160+ func.func @pad_and_transfer_write_static_non_zero_offset (
161+ %arg0: tensor <5 x6 xf32 >) -> tensor <5 x6 xf32 > {
162+ %c0 = arith.constant 0 : index
163+ %c5 = arith.constant 5.0 : f32
164+ %0 = tensor.pad %arg0 low [0 , 0 ] high [5 , 7 ] {
165+ ^bb0 (%arg2: index , %arg3: index ):
166+ tensor.yield %c5 : f32
167+ } : tensor <5 x6 xf32 > to tensor <10 x13 xf32 >
168+ %1 = call @make_vector () : () -> vector <7 x9 xf32 >
169+ %2 = vector.transfer_write %1 , %0 [%c0 , %c0 ]
170+ : vector <7 x9 xf32 >, tensor <10 x13 xf32 >
171+ %3 = tensor.extract_slice %2 [0 , 1 ] [5 , 6 ] [1 , 1 ] : tensor <10 x13 xf32 > to tensor <5 x6 xf32 >
172+ return %3 : tensor <5 x6 xf32 >
173+ }
174+
175+ module attributes {transform.with_named_sequence } {
176+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
177+ %func_op = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.op <" func.func" >
178+
179+ transform.apply_patterns to %func_op {
180+ transform.apply_patterns.linalg.pad_vectorization
181+ } : !transform.op <" func.func" >
182+ transform.yield
183+ }
184+ }
117185
118186// -----
119187
@@ -209,75 +277,3 @@ module attributes {transform.with_named_sequence} {
209277 transform.yield
210278 }
211279}
212-
213- // -----
214- func.func private @make_vector () -> vector <7 x9 xf32 >
215-
216- // Variant of @pad_and_transfer_write_static
217-
218- // CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_low_pad
219- // CHECK-NOT: tensor.pad
220- // CHECK: linalg.fill
221- func.func @pad_and_transfer_write_static_non_zero_low_pad (
222- %arg0: tensor <5 x6 xf32 >) -> tensor <5 x6 xf32 > {
223- %c0 = arith.constant 0 : index
224- %c5 = arith.constant 5.0 : f32
225- %0 = tensor.pad %arg0 low [0 , 1 ] high [5 , 6 ] {
226- ^bb0 (%arg2: index , %arg3: index ):
227- tensor.yield %c5 : f32
228- } : tensor <5 x6 xf32 > to tensor <10 x13 xf32 >
229- %1 = call @make_vector () : () -> vector <7 x9 xf32 >
230- %2 = vector.transfer_write %1 , %0 [%c0 , %c0 ]
231- : vector <7 x9 xf32 >, tensor <10 x13 xf32 >
232- %3 = tensor.extract_slice %2 [0 , 0 ] [5 , 6 ] [1 , 1 ] : tensor <10 x13 xf32 > to tensor <5 x6 xf32 >
233- return %3 : tensor <5 x6 xf32 >
234- }
235-
236- module attributes {transform.with_named_sequence } {
237- transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
238- %func_op = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.op <" func.func" >
239-
240- transform.apply_patterns to %func_op {
241- // TODO: Split into two tests, one for each pattern
242- transform.apply_patterns.linalg.decompose_pad
243- transform.apply_patterns.linalg.pad_vectorization
244- } : !transform.op <" func.func" >
245- transform.yield
246- }
247- }
248-
249- // -----
250- func.func private @make_vector () -> vector <7 x9 xf32 >
251-
252- // Variant of @pad_and_transfer_write_static
253-
254- // CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_offset
255- // CHECK-NOT: tensor.pad
256- // CHECK: linalg.fill
257- func.func @pad_and_transfer_write_static_non_zero_offset (
258- %arg0: tensor <5 x6 xf32 >) -> tensor <5 x6 xf32 > {
259- %c0 = arith.constant 0 : index
260- %c5 = arith.constant 5.0 : f32
261- %0 = tensor.pad %arg0 low [0 , 1 ] high [5 , 6 ] {
262- ^bb0 (%arg2: index , %arg3: index ):
263- tensor.yield %c5 : f32
264- } : tensor <5 x6 xf32 > to tensor <10 x13 xf32 >
265- %1 = call @make_vector () : () -> vector <7 x9 xf32 >
266- %2 = vector.transfer_write %1 , %0 [%c0 , %c0 ]
267- : vector <7 x9 xf32 >, tensor <10 x13 xf32 >
268- %3 = tensor.extract_slice %2 [0 , 1 ] [5 , 6 ] [1 , 1 ] : tensor <10 x13 xf32 > to tensor <5 x6 xf32 >
269- return %3 : tensor <5 x6 xf32 >
270- }
271-
272- module attributes {transform.with_named_sequence } {
273- transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
274- %func_op = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.op <" func.func" >
275-
276- transform.apply_patterns to %func_op {
277- // TODO: Split into two tests, one for each pattern
278- transform.apply_patterns.linalg.decompose_pad
279- transform.apply_patterns.linalg.pad_vectorization
280- } : !transform.op <" func.func" >
281- transform.yield
282- }
283- }
0 commit comments