1313
1414// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map
1515// CHECK-SAME: %[[VEC:.*]]: vector<4x8xi16>,
16- // CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) {
16+ // CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>
1717// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
1818// CHECK: vector.transfer_write
1919// CHECK-NOT: permutation_map
2020// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
2121func.func @xfer_write_transposing_permutation_map (
2222 %vec: vector <4 x8 xi16 >,
23- %mem: memref <2 x2 x8 x4 xi16 >) {
23+ %mem: memref <2 x2 x8 x4 xi16 >,
24+ %idx: index ) {
2425
25- %c0 = arith.constant 0 : index
26- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] {
26+ vector.transfer_write %vec , %mem [%idx , %idx , %idx , %idx ] {
2727 in_bounds = [true , true ],
2828 permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d3 , d2 )>
2929 } : vector <4 x8 xi16 >, memref <2 x2 x8 x4 xi16 >
3030
3131 return
3232}
3333
34- // Even with out-of-bounds, it is safe to apply this pattern
34+ // Even with out-of-bounds accesses, it is safe to apply this pattern
35+
3536// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds
3637// CHECK-SAME: %[[VEC:.*]]: vector<4x8xi16>,
37- // CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>) {
38- // CHECK: %[[C0:.*]] = arith.constant 0 : index
38+ // CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>,
39+ // CHECK-SAME: %[[IDX:.*]]: index) {
3940// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
4041// Expect the in_bounds attribute to be preserved. Since we don't print it when
4142// all flags are "false", it should not appear in the output.
4243// CHECK-NOT: in_bounds
4344// CHECK: vector.transfer_write
4445// CHECK-NOT: permutation_map
45- // CHECK-SAME: %[[TR]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
46+ // CHECK-SAME: %[[TR]], %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
4647func.func @xfer_write_transposing_permutation_map_out_of_bounds (
4748 %vec: vector <4 x8 xi16 >,
48- %mem: memref <2 x2 x?x?xi16 >) {
49+ %mem: memref <2 x2 x?x?xi16 >,
50+ %idx: index ) {
4951
50- %c0 = arith.constant 0 : index
51- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] {
52+ vector.transfer_write %vec , %mem [%idx , %idx , %idx , %idx ] {
5253 in_bounds = [false , false ],
5354 permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d3 , d2 )>
5455 } : vector <4 x8 xi16 >, memref <2 x2 x?x?xi16 >
@@ -59,18 +60,19 @@ func.func @xfer_write_transposing_permutation_map_out_of_bounds(
5960// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
6061// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
6162// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
62- // CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>) {
63+ // CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>
6364// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
6465// CHECK: vector.transfer_write
6566// CHECK-NOT: permutation_map
6667// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
6768func.func @xfer_write_transposing_permutation_map_with_mask_scalable (
6869 %vec: vector <4 x[8 ]xi16 >,
6970 %mem: memref <2 x2 x?x4 xi16 >,
70- %mask: vector <[8 ]x4 xi1 >) {
71+ %mask: vector <[8 ]x4 xi1 >,
72+ %idx: index ) {
7173
7274 %c0 = arith.constant 0 : index
73- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ], %mask {
75+ vector.transfer_write %vec , %mem [%idx , %idx , %idx , %idx ], %mask {
7476 in_bounds = [true , true ],
7577 permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d3 , d2 )>
7678 } : vector <4 x[8 ]xi16 >, memref <2 x2 x?x4 xi16 >
@@ -79,16 +81,18 @@ func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
7981}
8082
8183// Masked version is not supported
84+
8285// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_masked
8386// CHECK-NOT: vector.transpose
8487func.func @xfer_write_transposing_permutation_map_masked (
8588 %vec: vector <4 x8 xi16 >,
8689 %mem: memref <2 x2 x8 x4 xi16 >,
87- %mask: vector <8 x4 xi1 >) {
90+ %mask: vector <8 x4 xi1 >,
91+ %idx: index ) {
8892
8993 %c0 = arith.constant 0 : index
9094 vector.mask %mask {
91- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] {
95+ vector.transfer_write %vec , %mem [%idx , %idx , %idx , %idx ] {
9296 in_bounds = [true , true ],
9397 permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d3 , d2 )>
9498 } : vector <4 x8 xi16 >, memref <2 x2 x8 x4 xi16 >
@@ -128,7 +132,8 @@ func.func @xfer_write_non_transposing_permutation_map(
128132 return
129133}
130134
131- // Even with out-of-bounds, it is safe to apply this pattern
135+ // Even with out-of-bounds accesses, it is safe to apply this pattern
136+
132137// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
133138// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
134139// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
@@ -157,8 +162,7 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
157162// CHECK: func.func @permutation_with_mask_xfer_write_scalable(
158163// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
159164// CHECK-SAME: %[[MEM:.*]]: memref<1x4x?x1xi16>,
160- // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
161- // CHECK: %[[C0:.*]] = arith.constant 0 : index
165+ // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>
162166// CHECK: %[[BC_1:.*]] = vector.broadcast %[[VEC]] : vector<4x[8]xi16> to vector<1x4x[8]xi16>
163167// CHECK: %[[BC_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x4x[8]xi1>
164168// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BC_2]], [1, 2, 0] : vector<1x4x[8]xi1> to vector<4x[8]x1xi1>
@@ -167,18 +171,19 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
167171func.func @permutation_with_mask_xfer_write_scalable (
168172 %vec: vector <4 x[8 ]xi16 >,
169173 %mem: memref <1 x4 x?x1 xi16 >,
170- %mask: vector <4 x[8 ]xi1 >){
174+ %mask: vector <4 x[8 ]xi1 >,
175+ %idx: index ){
171176
172- %c0 = arith.constant 0 : index
173- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ], %mask {
177+ vector.transfer_write %vec , %mem [%idx , %idx , %idx , %idx ], %mask {
174178 in_bounds = [true , true ],
175179 permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>
176180 } : vector <4 x[8 ]xi16 >, memref <1 x4 x?x1 xi16 >
177181
178182 return
179183}
180184
181- // transfer_write in MaskOp case not supported.
185+ // Masked version is not supported
186+
182187// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
183188// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
184189// CHECK-SAME: %[[VEC:.*]]: vector<16xf32>,
@@ -204,18 +209,19 @@ func.func @masked_permutation_xfer_write_fixed_width(
204209// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
205210// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
206211// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>,
207- // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
212+ // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>
208213// CHECK-SAME: -> tensor<?x?x?x?xf32> {
209214// CHECK-NOT: vector.transpose
210215// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
211216func.func @masked_permutation_xfer_write_scalable (
212217 %vec: vector <4 x[8 ]xi16 >,
213218 %dest: tensor <?x?x?x?xf32 >,
214- %mask: vector <4 x[8 ]xi1 >) -> tensor <?x?x?x?xf32 > {
219+ %mask: vector <4 x[8 ]xi1 >,
220+ %idx: index ) -> tensor <?x?x?x?xf32 > {
215221
216222 %c0 = arith.constant 0 : index
217223 %res = vector.mask %mask {
218- vector.transfer_write %vec , %dest [%c0 , %c0 , %c0 , %c0 ] {
224+ vector.transfer_write %vec , %dest [%idx , %idx , %idx , %idx ] {
219225 in_bounds = [true , true ],
220226 permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>
221227 } : vector <4 x[8 ]xi16 >, tensor <?x?x?x?xf32 >
@@ -224,22 +230,23 @@ func.func @masked_permutation_xfer_write_scalable(
224230 return %res : tensor <?x?x?x?xf32 >
225231}
226232
227- // transfer_write in MaskOp case not supported.
233+ // Masked version is not supported
234+
228235// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
229236// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>
230237// CHECK-SAME: %[[VEC:.*]]: vector<14x8x16xf32>
231- // CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
238+ // CHECK-SAME: %[[DIM:.*]]: index, %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
232239// CHECK-NOT: vector.broadcast
233240// CHECK: vector.mask %0 { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
234241func.func @masked_non_permutation_xfer_write_fixed_width (
235242 %dest : tensor <?x?x?x?xf32 >,
236243 %vec : vector <14 x8 x16 xf32 >,
237- %dim : index ) -> tensor <?x?x?x?xf32 > {
244+ %dim : index ,
245+ %idx: index ) -> tensor <?x?x?x?xf32 > {
238246
239- %c0 = arith.constant 0 : index
240247 %mask = vector.create_mask %dim , %dim , %dim : vector <14 x8 x16 xi1 >
241248 %res = vector.mask %mask {
242- vector.transfer_write %vec , %dest [%c0 , %c0 , %c0 , %c0 ] {
249+ vector.transfer_write %vec , %dest [%idx , %idx , %idx , %idx ] {
243250 in_bounds = [false , false , true ],
244251 permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d3 )>
245252 } : vector <14 x8 x16 xf32 >, tensor <?x?x?x?xf32 >
@@ -259,25 +266,23 @@ func.func @masked_non_permutation_xfer_write_fixed_width(
259266
260267// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_fixed_width(
261268// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
262- // CHECK-SAME: %[[IDX_1:.*]]: index,
263- // CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x4x2xf32> {
264- // CHECK: %[[C0:.*]] = arith.constant 0 : index
269+ // CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x4x2xf32> {
265270// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
266- // CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x4xi1>
267- // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
271+ // CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x4xi1>
272+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
268273// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x4xf32> to vector<8x2x4xf32>
269274// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x4xf32> to vector<8x4x2xf32>
270275// CHECK: return %[[TRANSPOSE]] : vector<8x4x2xf32>
271276func.func @permutation_with_mask_xfer_read_fixed_width (
272277 %mem: memref <?x?xf32 >,
273278 %dim_1: index ,
274- %dim_2: index ) -> (vector <8 x4 x2 xf32 >) {
279+ %dim_2: index ,
280+ %idx: index ) -> (vector <8 x4 x2 xf32 >) {
275281
276- %c0 = arith.constant 0 : index
277- %cst_0 = arith.constant 0.000000e+00 : f32
282+ %pad = arith.constant 0.000000e+00 : f32
278283
279284 %mask = vector.create_mask %dim_2 , %dim_1 : vector <2 x4 xi1 >
280- %res = vector.transfer_read %mem [%c0 , %c0 ], %cst_0 , %mask {
285+ %res = vector.transfer_read %mem [%idx , %idx ], %pad , %mask {
281286 in_bounds = [true , true , true ],
282287 permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>
283288 } : memref <?x?xf32 >, vector <8 x4 x2 xf32 >
@@ -287,46 +292,45 @@ func.func @permutation_with_mask_xfer_read_fixed_width(
287292
288293// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_scalable(
289294// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
290- // CHECK-SAME: %[[IDX_1:.*]]: index,
291- // CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x[4]x2xf32> {
292- // CHECK: %[[C0:.*]] = arith.constant 0 : index
293- // CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
294- // CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x[4]xi1>
295- // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
295+ // CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x[4]x2xf32> {
296+ // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
297+ // CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x[4]xi1>
298+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
296299// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
297300// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
298301// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
299302func.func @permutation_with_mask_xfer_read_scalable (
300303 %mem: memref <?x?xf32 >,
301304 %dim_1: index ,
302- %dim_2: index ) -> (vector <8 x[4 ]x2 xf32 >) {
305+ %dim_2: index ,
306+ %idx: index ) -> (vector <8 x[4 ]x2 xf32 >) {
303307
304- %c0 = arith.constant 0 : index
305- %cst_0 = arith.constant 0.000000e+00 : f32
308+ %pad = arith.constant 0.000000e+00 : f32
306309
307310 %mask = vector.create_mask %dim_2 , %dim_1 : vector <2 x[4 ]xi1 >
308- %res = vector.transfer_read %mem [%c0 , %c0 ], %cst_0 , %mask {
311+ %res = vector.transfer_read %mem [%idx , %idx ], %pad , %mask {
309312 in_bounds = [true , true , true ],
310313 permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>
311314 } : memref <?x?xf32 >, vector <8 x[4 ]x2 xf32 >
312315
313316 return %res : vector <8 x[4 ]x2 xf32 >
314317}
315318
316- // transfer_read in MaskOp case not supported.
319+ // Masked version is not supported
320+
317321// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
318322// CHECK-SAME: %[[DEST:.*]]: tensor<?x1xf32>,
319323// CHECK-SAME: %[[MASK:.*]]: vector<4x1xi1>
320324// CHECK-NOT: vector.transpose
321325// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
322326func.func @masked_permutation_xfer_read_fixed_width (
323327 %dest: tensor <?x1 xf32 >,
324- %mask : vector <4 x1 xi1 >) {
328+ %mask : vector <4 x1 xi1 >,
329+ %idx: index ) {
325330
326- %cst = arith.constant 0.000000e+00 : f32
327- %c0 = arith.constant 0 : index
331+ %pad = arith.constant 0.000000e+00 : f32
328332 %3 = vector.mask %mask {
329- vector.transfer_read %dest [%c0 , %c0 ], %cst {
333+ vector.transfer_read %dest [%idx , %idx ], %pad {
330334 permutation_map = affine_map <(d0 , d1 ) -> (d1 , 0 , d0 )>
331335 } : tensor <?x1 xf32 >, vector <1 x4 x4 xf32 >
332336 } : vector <4 x1 xi1 > -> vector <1 x4 x4 xf32 >
@@ -337,18 +341,18 @@ func.func @masked_permutation_xfer_read_fixed_width(
337341
338342// CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable(
339343// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
340- // CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
344+ // CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>
341345// CHECK-NOT: vector.transpose
342346// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
343347func.func @masked_permutation_xfer_read_scalable (
344348 %dest: tensor <?x?xf32 >,
345- %mask : vector <2 x[4 ]xi1 >) -> vector <8 x[4 ]x2 xf32 > {
349+ %mask : vector <2 x[4 ]xi1 >,
350+ %idx: index ) -> vector <8 x[4 ]x2 xf32 > {
346351
347- %c0 = arith.constant 0 : index
348- %cst_0 = arith.constant 0.000000e+00 : f32
352+ %pad = arith.constant 0.000000e+00 : f32
349353
350354 %res = vector.mask %mask {
351- vector.transfer_read %dest [%c0 , %c0 ], %cst_0 {
355+ vector.transfer_read %dest [%idx , %idx ], %pad {
352356 in_bounds = [true , true , true ],
353357 permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>
354358 } : tensor <?x?xf32 >, vector <8 x[4 ]x2 xf32 >
@@ -377,41 +381,41 @@ module attributes {transform.with_named_sequence} {
377381
378382// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
379383// CHECK: func.func @transfer_read_reduce_rank_scalable(
380- // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
381- // CHECK: %[[C0:.*]] = arith.constant 0 : index
382- // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
384+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
385+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
383386// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
384387// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
385388func.func @transfer_read_reduce_rank_scalable (
386- %mem: memref <?x?x?x?xf32 >) -> vector <8 x[4 ]x2 x3 xf32 > {
389+ %mem: memref <?x?x?x?xf32 >, %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
387390
388- %c0 = arith.constant 0 : index
389- %cst_0 = arith.constant 0.000000e+00 : f32
391+ %pad = arith.constant 0.000000e+00 : f32
390392
391- %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst_0 {
393+ %res = vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
392394 in_bounds = [true , true , true , true ],
393395 permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
394396 } : memref <?x?x?x?xf32 >, vector <8 x[4 ]x2 x3 xf32 >
395397
396398 return %res : vector <8 x[4 ]x2 x3 xf32 >
397399}
398400
399- // Masked case not supported.
401+ // Masked version is not supported
402+
400403// CHECK-LABEL: func.func @masked_transfer_read_reduce_rank(
401404// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
402- // CHECK-SAME: %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
405+ // CHECK-SAME: %[[DIM:.*]]: index,
406+ // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
403407// CHECK-NOT: vector.broadcast
404408// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
405409func.func @masked_transfer_read_reduce_rank (
406410 %mem: memref <?x?x?x?xf32 >,
407- %dim: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
411+ %dim: index ,
412+ %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
408413
409- %c0 = arith.constant 0 : index
410- %cst_0 = arith.constant 0.000000e+00 : f32
414+ %pad = arith.constant 0.000000e+00 : f32
411415 %mask = vector.create_mask %dim , %dim: vector <[4 ]x3 xi1 >
412416
413417 %res = vector.mask %mask {
414- vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst_0 {
418+ vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
415419 in_bounds = [true , true , true , true ],
416420 permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
417421 } : memref <?x?x?x?xf32 >, vector <8 x[4 ]x2 x3 xf32 >
0 commit comments