@@ -70,10 +70,11 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
7070
7171// -----
7272
73- // The shape of the memref and the vector don't match, but the vector is a
74- // contiguous subset of the memref, so "flattenable".
73+ // The shape of the memref and the vector don't match, but the vector,
74+ // ignoring the unit dimensions, is a contiguous subset of the memref,
75+ // so "flattenable"
7576
76- func.func @transfer_read_dims_mismatch_contiguous (
77+ func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
7778 %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
7879
7980 %c0 = arith.constant 0 : index
@@ -83,7 +84,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
8384 return %res : vector <1 x1 x2 x2 xi8 >
8485}
8586
86- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous (
87+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
8788// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
8889// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
8990// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
@@ -92,7 +93,37 @@ func.func @transfer_read_dims_mismatch_contiguous(
9293// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
9394// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
9495
95- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
96+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
97+ // CHECK-128B: memref.collapse_shape
98+
99+ // -----
100+
101+ // The shape of the memref and the vector don't match, but the vector is a
102+ // contiguous subset of the memref, so "flattenable"
103+
104+ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims (
105+ %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <2 x3 x2 xi8 > {
106+
107+ %c0 = arith.constant 0 : index
108+ %cst = arith.constant 0 : i8
109+ %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst :
110+ memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, vector <2 x3 x2 xi8 >
111+ return %res : vector <2 x3 x2 xi8 >
112+ }
113+
114+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
115+ // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
116+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
117+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
118+ // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
119+ // CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
120+ // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
121+ // CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]]], %[[C0_I8]] {in_bounds = [true]}
122+ // CHECK-SAME: : memref<120xi8, strided<[1], offset: ?>>, vector<12xi8>
123+ // CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
124+ // CHECK: return %[[VEC]] : vector<2x3x2xi8>
125+
126+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
96127// CHECK-128B: memref.collapse_shape
97128
98129// -----
@@ -380,7 +411,7 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
380411
381412// -----
382413
383- func.func @transfer_write_dims_mismatch_contiguous (
414+ func.func @transfer_write_dims_mismatch_contiguous_unit_dims (
384415 %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
385416 %vec : vector <1 x1 x2 x2 xi8 >) {
386417
@@ -390,15 +421,41 @@ func.func @transfer_write_dims_mismatch_contiguous(
390421 return
391422}
392423
393- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
424+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
394425// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
395426// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
396427// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
397428// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
398429// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
399430// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
400431
401- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
432+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
433+ // CHECK-128B: memref.collapse_shape
434+
435+ // -----
436+
437+ func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims (
438+ %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
439+ %vec : vector <2 x2 xi8 >) {
440+
441+ %c0 = arith.constant 0 : index
442+ vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] :
443+ vector <2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
444+ return
445+ }
446+
447+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
448+ // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
449+ // CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
450+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
451+ // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
452+ // CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
453+ // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
454+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
455+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]]] {in_bounds = [true]}
456+ // CHECK-SAME: : vector<4xi8>, memref<120xi8, {{.+}}>
457+
458+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
402459// CHECK-128B: memref.collapse_shape
403460
404461// -----
@@ -620,6 +677,7 @@ func.func @negative_out_of_bound_transfer_read(
620677}
621678// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
622679// CHECK-NOT: memref.collapse_shape
680+ // CHECK-NOT: vector.shape_cast
623681
624682// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
625683// CHECK-128B-NOT: memref.collapse_shape
@@ -638,45 +696,9 @@ func.func @negative_out_of_bound_transfer_write(
638696}
639697// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
640698// CHECK-NOT: memref.collapse_shape
699+ // CHECK-NOT: vector.shape_cast
641700
642701// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
643702// CHECK-128B-NOT: memref.collapse_shape
644703// CHECK-128B-NOT: vector.shape_cast
645704
646- // -----
647-
648- func.func @discontig_mem_contig_slice (
649- %mem : memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>, %vec : vector <1 x1 x8 xi32 >) {
650- %c0 = arith.constant 0 : index
651- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 ] {in_bounds = [true , true , true ]} :
652- vector <1 x1 x8 xi32 >, memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>
653- return
654- }
655-
656- // CHECK-LABEL: func.func @discontig_mem_contig_slice
657- // CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
658- // CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
659- // CHECK: %[[C0:.+]] = arith.constant 0 : index
660- // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
661- // CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
662- // CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
663-
664- // CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
665- // CHECK-128B-NOT: vector.shape_cast
666-
667- // -----
668-
669- func.func @discontig_mem_discontig_slice (
670- %mem : memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>, %vec : vector <1 x2 x8 xi32 >) {
671- %c0 = arith.constant 0 : index
672- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 ] {in_bounds = [true , true , true ]} :
673- vector <1 x2 x8 xi32 >, memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>
674- return
675- }
676-
677- // CHECK-LABEL: func.func @discontig_mem_discontig_slice
678- // CHECK-NOT: vector.shape_cast
679-
680- // CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
681- // CHECK-128B-NOT: vector.shape_cast
682-
0 commit comments