Skip to content

Commit 41be5bb

Browse files
authored
[mlir][vector] Update tests for xfer permutation lowering (3/N) (#127320)
* Remove `vector.create_mask` from tests. Instead, pass masks as arguments. This simplifies the tests without sacrificing test coverage. * Update `@xfer_read_minor_identity_tranposed_with_mask_scalable` to use similar shapes as other tests and to avoid using test Ops (e.g. `@test.some_use`). This improves consistency between tests. * Fix some comment typos.
1 parent c71f914 commit 41be5bb

File tree

1 file changed

+38
-38
lines changed

1 file changed

+38
-38
lines changed

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
22

3+
// TODO: Review the usage of `in_bounds` and remove where not affecting the
4+
// generated output.
5+
36
/// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
47

58
///----------------------------------------------------------------------------------------
@@ -106,8 +109,8 @@ func.func @xfer_write_minor_identity_transposed_map_masked(
106109
/// (neither a minor identity nor transposed minor identity map)
107110
/// OUT 1: vector.broadcast + vector.transfer_write
108111
/// (transposed minor identity)
109-
/// OUT 2: vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_write
110-
/// (minor identity)
112+
/// OUT 2: vector.transfer_write -> vector.broadcast + vector.transpose
113+
/// + vector.transfer_write (minor identity)
111114
///----------------------------------------------------------------------------------------
112115

113116
// CHECK-LABEL: func.func @xfer_write_non_minor_identity(
@@ -233,16 +236,16 @@ func.func @xfer_write_non_minor_identity_masked_scalable(
233236
// CHECK-LABEL: func @xfer_write_non_minor_identity_masked_2
234237
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>
235238
// CHECK-SAME: %[[VEC:.*]]: vector<14x8x16xf32>
236-
// CHECK-SAME: %[[DIM:.*]]: index, %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
239+
// CHECK-SAME: %[[MASK:.*]]: vector<14x8x16xi1>
240+
// CHECK-SAME: %[[DIM:.*]]: index
237241
// CHECK-NOT: vector.broadcast
238-
// CHECK: vector.mask %0 { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
242+
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
239243
func.func @xfer_write_non_minor_identity_masked_2(
240244
%dest : tensor<?x?x?x?xf32>,
241245
%vec : vector<14x8x16xf32>,
242-
%dim : index,
246+
%mask: vector<14x8x16xi1>,
243247
%idx: index) -> tensor<?x?x?x?xf32> {
244248

245-
%mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
246249
%res = vector.mask %mask {
247250
vector.transfer_write %vec, %dest[%idx, %idx, %idx, %idx] {
248251
in_bounds = [false, false, true],
@@ -259,29 +262,27 @@ func.func @xfer_write_non_minor_identity_masked_2(
259262
///
260263
/// IN: vector.transfer_read
261264
/// (_transposed_ minor identity permutation map, with 0 or more broadcast dims)
262-
/// OUT: vector.transpose + vector.transfer_write
265+
/// OUT: vector.transfer_read + vector.broadcast + vector.transpose
263266
/// (minor identity permutation map with 0 or more leading broadcast dims)
264267
///----------------------------------------------------------------------------------------
265268
/// TODO: Inner broadcast dim - see also the block at the bottom of this file
266269

267-
// CHECK-LABEL: func.func @xfer_read_minor_identity_tranposed_with_mask
270+
// CHECK-LABEL: func.func @xfer_read_minor_identity_transposed_with_mask
268271
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
269-
// CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x4x2xf32> {
272+
// CHECK-SAME: %[[MASK:.*]]: vector<2x4xi1>
273+
// CHECK-SAME: %[[IDX:.*]]: index
270274
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
271-
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x4xi1>
272275
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
273276
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x4xf32> to vector<8x2x4xf32>
274277
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x4xf32> to vector<8x4x2xf32>
275278
// CHECK: return %[[TRANSPOSE]] : vector<8x4x2xf32>
276-
func.func @xfer_read_minor_identity_tranposed_with_mask(
279+
func.func @xfer_read_minor_identity_transposed_with_mask(
277280
%mem: memref<?x?xf32>,
278-
%dim_1: index,
279-
%dim_2: index,
281+
%mask: vector<2x4xi1>,
280282
%idx: index) -> (vector<8x4x2xf32>) {
281283

282284
%pad = arith.constant 0.000000e+00 : f32
283285

284-
%mask = vector.create_mask %dim_2, %dim_1 : vector<2x4xi1>
285286
%res = vector.transfer_read %mem[%idx, %idx], %pad, %mask {
286287
in_bounds = [true, true, true],
287288
permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
@@ -290,24 +291,22 @@ func.func @xfer_read_minor_identity_tranposed_with_mask(
290291
return %res : vector<8x4x2xf32>
291292
}
292293

293-
// CHECK-LABEL: func.func @xfer_read_minor_identity_tranposed_with_mask_scalable(
294+
// CHECK-LABEL: func.func @xfer_read_minor_identity_transposed_with_mask_scalable(
294295
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
295-
// CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x[4]x2xf32> {
296+
// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>
297+
// CHECK-SAME: %[[IDX:.*]]: index
296298
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
297-
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x[4]xi1>
298299
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
299300
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
300301
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
301302
// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
302-
func.func @xfer_read_minor_identity_tranposed_with_mask_scalable(
303+
func.func @xfer_read_minor_identity_transposed_with_mask_scalable(
303304
%mem: memref<?x?xf32>,
304-
%dim_1: index,
305-
%dim_2: index,
305+
%mask: vector<2x[4]xi1>,
306306
%idx: index) -> (vector<8x[4]x2xf32>) {
307307

308308
%pad = arith.constant 0.000000e+00 : f32
309309

310-
%mask = vector.create_mask %dim_2, %dim_1 : vector<2x[4]xi1>
311310
%res = vector.transfer_read %mem[%idx, %idx], %pad, %mask {
312311
in_bounds = [true, true, true],
313312
permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
@@ -319,24 +318,26 @@ func.func @xfer_read_minor_identity_tranposed_with_mask_scalable(
319318
// Masked version is not supported
320319

321320
// CHECK-LABEL: func @xfer_read_minor_identity_transposed_masked(
322-
// CHECK-SAME: %[[DEST:.*]]: tensor<?x1xf32>,
323-
// CHECK-SAME: %[[MASK:.*]]: vector<4x1xi1>
321+
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
322+
// CHECK-SAME: %[[MASK:.*]]: vector<2x4xi1>
323+
// CHECK-SAME: %[[IDX:.*]]: index
324324
// CHECK-NOT: vector.transpose
325-
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
325+
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x?xf32>, vector<8x4x2xf32> } : vector<2x4xi1> -> vector<8x4x2xf32>
326326
func.func @xfer_read_minor_identity_transposed_masked(
327-
%dest: tensor<?x1xf32>,
328-
%mask : vector<4x1xi1>,
329-
%idx: index) {
327+
%dest: tensor<?x?xf32>,
328+
%mask: vector<2x4xi1>,
329+
%idx: index) -> (vector<8x4x2xf32>) {
330330

331331
%pad = arith.constant 0.000000e+00 : f32
332-
%3 = vector.mask %mask {
332+
333+
%res = vector.mask %mask {
333334
vector.transfer_read %dest[%idx, %idx], %pad {
334-
permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>
335-
} : tensor<?x1xf32>, vector<1x4x4xf32>
336-
} : vector<4x1xi1> -> vector<1x4x4xf32>
335+
in_bounds = [true, true, true],
336+
permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
337+
} : tensor<?x?xf32>, vector<8x4x2xf32>
338+
} : vector<2x4xi1> -> vector<8x4x2xf32>
337339

338-
"test.some_use"(%3) : (vector<1x4x4xf32>) -> ()
339-
return
340+
return %res : vector<8x4x2xf32>
340341
}
341342

342343
// CHECK-LABEL: func.func @xfer_read_minor_identity_transposed_masked_scalable(
@@ -346,7 +347,7 @@ func.func @xfer_read_minor_identity_transposed_masked(
346347
// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
347348
func.func @xfer_read_minor_identity_transposed_masked_scalable(
348349
%dest: tensor<?x?xf32>,
349-
%mask : vector<2x[4]xi1>,
350+
%mask: vector<2x[4]xi1>,
350351
%idx: index) -> vector<8x[4]x2xf32> {
351352

352353
%pad = arith.constant 0.000000e+00 : f32
@@ -388,17 +389,16 @@ func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
388389

389390
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_masked
390391
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
391-
// CHECK-SAME: %[[DIM:.*]]: index,
392+
// CHECK-SAME: %[[MASK:.*]]: vector<[4]x3xi1>
392393
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
393394
// CHECK-NOT: vector.broadcast
394-
// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
395+
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
395396
func.func @xfer_read_minor_identitiy_bcast_dims_masked(
396397
%mem: memref<?x?x?x?xf32>,
397-
%dim: index,
398+
%mask: vector<[4]x3xi1>,
398399
%idx: index) -> vector<8x[4]x2x3xf32> {
399400

400401
%pad = arith.constant 0.000000e+00 : f32
401-
%mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>
402402

403403
%res = vector.mask %mask {
404404
vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {

0 commit comments

Comments
 (0)