@@ -70,28 +70,29 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
7070
7171// -----
7272
73- // The shape of the memref and the vector don't match, but the vector,
74- // ignoring the unit dimensions, is a contiguous subset of the memref,
75- // so "flattenable"
73+ // The shape of the memref and the vector don't match, but the vector is a
74+ // contiguous subset of the memref, so "flattenable". The leading unit dimensions
75+ // of the vector have no effect on the memref area read even if they
76+ // span a non-contiguous part of the memref.
7677
7778func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
78- %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
79+ %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
7980
8081 %c0 = arith.constant 0 : index
8182 %cst = arith.constant 0 : i8
8283 %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst :
83- memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x2 x2 xi8 >
84+ memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x2 x2 xi8 >
8485 return %res : vector <1 x1 x2 x2 xi8 >
8586}
8687
8788// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
88- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
89+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
8990// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
9091// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
9192// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
9293// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
93- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24 , 6, 1], offset: ?>>
94- // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[24 , 6, 1], offset: ?>>, vector<4xi8>
94+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48 , 6, 1], offset: ?>>
95+ // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48 , 6, 1], offset: ?>>, vector<4xi8>
9596// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
9697// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
9798
@@ -412,31 +413,40 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
412413
413414// -----
414415
416+ // The shape of the memref and the vector don't match, but the vector is a
417+ // contiguous subset of the memref, so "flattenable". The leading unit dimensions
418+ // of the vector have no effect on the memref area written even if they
419+ // span a non-contiguous part of the memref.
420+
415421func.func @transfer_write_dims_mismatch_contiguous_unit_dims (
416- %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
422+ %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>,
417423 %vec : vector <1 x1 x2 x2 xi8 >) {
418424
419425 %c0 = arith.constant 0 : index
420426 vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] :
421- vector <1 x1 x2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
427+ vector <1 x1 x2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>
422428 return
423429}
424430
425431// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
426- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>>,
432+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>>,
427433// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
428434// CHECK: %[[C0:.*]] = arith.constant 0 : index
429435// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
430436// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
431- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24 , 6, 1], offset: ?>>
437+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48 , 6, 1], offset: ?>>
432438// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
433- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
439+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
440+ // CHECK-SAME: {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
434441
435442// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
436443// CHECK-128B: memref.collapse_shape
437444
438445// -----
439446
447+ // The shape of the memref and the vector don't match, but the vector is a
448+ // contiguous subset of the memref, so "flattenable".
449+
440450func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims (
441451 %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
442452 %vec : vector <2 x2 xi8 >) {
0 commit comments