@@ -70,10 +70,11 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
7070
7171// -----
7272
73- // The shape of the memref and the vector don't match, but the vector is a
74- // contiguous subset of the memref, so "flattenable".
73+ // The shape of the memref and the vector don't match, but the vector,
74+ // ignoring the unit dimensions, is a contiguous subset of the memref,
75+ // so "flattenable"
7576
76- func.func @transfer_read_dims_mismatch_contiguous (
77+ func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
7778 %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
7879
7980 %c0 = arith.constant 0 : index
@@ -83,7 +84,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
8384 return %res : vector <1 x1 x2 x2 xi8 >
8485}
8586
86- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous (
87+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
8788// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
8889// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
8990// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
@@ -92,7 +93,37 @@ func.func @transfer_read_dims_mismatch_contiguous(
9293// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
9394// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
9495
95- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
96+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
97+ // CHECK-128B: memref.collapse_shape
98+
99+ // -----
100+
101+ // The shape of the memref and the vector don't match, but the vector is a
102+ // contiguous subset of the memref, so "flattenable"
103+
104+ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims (
105+ %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <2 x3 x2 xi8 > {
106+
107+ %c0 = arith.constant 0 : index
108+ %cst = arith.constant 0 : i8
109+ %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst :
110+ memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, vector <2 x3 x2 xi8 >
111+ return %res : vector <2 x3 x2 xi8 >
112+ }
113+
114+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
115+ // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
116+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
117+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
118+ // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
119+ // CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
120+ // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
121+ // CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]]], %[[C0_I8]] {in_bounds = [true]}
122+ // CHECK-SAME: : memref<120xi8, strided<[1], offset: ?>>, vector<12xi8>
123+ // CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
124+ // CHECK: return %[[VEC]] : vector<2x3x2xi8>
125+
126+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
96127// CHECK-128B: memref.collapse_shape
97128
98129// -----
@@ -384,7 +415,7 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
384415
385416// -----
386417
387- func.func @transfer_write_dims_mismatch_contiguous (
418+ func.func @transfer_write_dims_mismatch_contiguous_unit_dims (
388419 %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
389420 %vec : vector <1 x1 x2 x2 xi8 >) {
390421
@@ -394,15 +425,41 @@ func.func @transfer_write_dims_mismatch_contiguous(
394425 return
395426}
396427
397- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
428+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
398429// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
399430// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
400431// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
401432// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
402433// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
403434// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
404435
405- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
436+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
437+ // CHECK-128B: memref.collapse_shape
438+
439+ // -----
440+
441+ func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims (
442+ %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
443+ %vec : vector <2 x2 xi8 >) {
444+
445+ %c0 = arith.constant 0 : index
446+ vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] :
447+ vector <2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
448+ return
449+ }
450+
451+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
452+ // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
453+ // CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
454+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
455+ // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
456+ // CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
457+ // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
458+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
459+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]]] {in_bounds = [true]}
460+ // CHECK-SAME: : vector<4xi8>, memref<120xi8, {{.+}}>
461+
462+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
406463// CHECK-128B: memref.collapse_shape
407464
408465// -----
@@ -626,6 +683,7 @@ func.func @negative_out_of_bound_transfer_read(
626683}
627684// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
628685// CHECK-NOT: memref.collapse_shape
686+ // CHECK-NOT: vector.shape_cast
629687
630688// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
631689// CHECK-128B-NOT: memref.collapse_shape
@@ -642,45 +700,9 @@ func.func @negative_out_of_bound_transfer_write(
642700}
643701// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
644702// CHECK-NOT: memref.collapse_shape
703+ // CHECK-NOT: vector.shape_cast
645704
646705// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
647706// CHECK-128B-NOT: memref.collapse_shape
648707// CHECK-128B-NOT: vector.shape_cast
649708
650- // -----
651-
652- func.func @discontig_mem_contig_slice (
653- %mem : memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>, %vec : vector <1 x1 x8 xi32 >) {
654- %c0 = arith.constant 0 : index
655- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 ] {in_bounds = [true , true , true ]} :
656- vector <1 x1 x8 xi32 >, memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>
657- return
658- }
659-
660- // CHECK-LABEL: func.func @discontig_mem_contig_slice
661- // CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
662- // CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
663- // CHECK: %[[C0:.+]] = arith.constant 0 : index
664- // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
665- // CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
666- // CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
667-
668- // CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
669- // CHECK-128B-NOT: vector.shape_cast
670-
671- // -----
672-
673- func.func @discontig_mem_discontig_slice (
674- %mem : memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>, %vec : vector <1 x2 x8 xi32 >) {
675- %c0 = arith.constant 0 : index
676- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 ] {in_bounds = [true , true , true ]} :
677- vector <1 x2 x8 xi32 >, memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>
678- return
679- }
680-
681- // CHECK-LABEL: func.func @discontig_mem_discontig_slice
682- // CHECK-NOT: vector.shape_cast
683-
684- // CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
685- // CHECK-128B-NOT: vector.shape_cast
686-
0 commit comments