11// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
22
3+ //-----------------------------------------------------------------------------
4+ // [Patterns: TransferWriteDropUnitDimsPattern, TransferReadeDropUnitDimsPattern]
5+ //-----------------------------------------------------------------------------
6+
37func.func @transfer_read_rank_reducing (
48 %arg : memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>) -> vector <3 x2 xi8 > {
59 %c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
1418// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
1519// CHECK: vector.transfer_read %[[SUBVIEW]]
1620
17- func.func @transfer_write_rank_reducing (%arg : memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>, %vec : vector <3 x2 xi8 >) {
21+ func.func @transfer_read_rank_reducing_masked (
22+ %arg : memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>,
23+ %mask: vector <3 x2 xi1 >) -> vector <3 x2 xi8 > {
24+ %c0 = arith.constant 0 : index
25+ %cst = arith.constant 0 : i8
26+ %v = vector.mask %mask {
27+ vector.transfer_read %arg [%c0 , %c0 , %c0 , %c0 ], %cst :
28+ memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>, vector <3 x2 xi8 >
29+ } : vector <3 x2 xi1 > -> vector <3 x2 xi8 >
30+ return %v : vector <3 x2 xi8 >
31+ }
32+ // CHECK-LABEL: func @transfer_read_rank_reducing_masked
33+ // CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
34+ // CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
35+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
36+ // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
37+ // CHECK: vector.mask %[[MASK]]
38+ // CHECK-SAME: vector.transfer_read %[[SUBVIEW]]
39+
40+ func.func @transfer_write_rank_reducing (
41+ %arg : memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>,
42+ %vec : vector <3 x2 xi8 >) {
43+
1844 %c0 = arith.constant 0 : index
1945 vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 ] :
2046 vector <3 x2 xi8 >, memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
2652// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
2753// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
2854
55+ func.func @transfer_write_rank_reducing_masked (
56+ %arg : memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>,
57+ %vec : vector <3 x2 xi8 >,
58+ %mask: vector <3 x2 xi1 >) {
59+ %c0 = arith.constant 0 : index
60+ vector.mask %mask {
61+ vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 ] :
62+ vector <3 x2 xi8 >, memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>
63+ } : vector <3 x2 xi1 >
64+ return
65+ }
66+ // CHECK-LABEL: func @transfer_write_rank_reducing_masked
67+ // CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
68+ // CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
69+ // CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
70+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
71+ // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
72+ // CHECK: vector.mask %[[MASK]]
73+ // CHECK-SAME: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
74+
2975func.func @transfer_read_and_vector_rank_reducing (
3076 %arg : memref <1 x1 x3 x2 x1 xf32 >) -> vector <3 x2 x1 xf32 > {
3177 %c0 = arith.constant 0 : index
@@ -68,6 +114,22 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d(
68114// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
69115// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
70116
117+ func.func @transfer_read_and_vector_rank_reducing_to_0d_masked (
118+ %arg : memref <1 x1 x1 x1 x1 xf32 >,
119+ %mask: vector <1 x1 x1 xi1 >) -> vector <1 x1 x1 xf32 > {
120+
121+ %c0 = arith.constant 0 : index
122+ %cst = arith.constant 0.0 : f32
123+ %v = vector.mask %mask {
124+ vector.transfer_read %arg [%c0 , %c0 , %c0 , %c0 , %c0 ], %cst
125+ : memref <1 x1 x1 x1 x1 xf32 >, vector <1 x1 x1 xf32 >
126+ } : vector <1 x1 x1 xi1 > -> vector <1 x1 x1 xf32 >
127+ return %v : vector <1 x1 x1 xf32 >
128+ }
129+ // CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d_masked
130+ // CHECK-NOT: vector.shape_cast
131+ // CHECK-NOT: memref.subview
132+
71133func.func @transfer_write_and_vector_rank_reducing_to_0d (
72134 %arg : memref <1 x1 x1 x1 x1 xf32 >,
73135 %vec : vector <1 x1 x1 xf32 >) {
@@ -82,6 +144,23 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
82144// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
83145// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
84146
147+ func.func @transfer_write_and_vector_rank_reducing_to_0d_masked (
148+ %arg : memref <1 x1 x1 x1 x1 xf32 >,
149+ %vec : vector <1 x1 x1 xf32 >,
150+ %mask: vector <1 x1 x1 xi1 >) {
151+
152+ %c0 = arith.constant 0 : index
153+ %cst = arith.constant 0.0 : f32
154+ vector.mask %mask {
155+ vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 , %c0 ] :
156+ vector <1 x1 x1 xf32 >, memref <1 x1 x1 x1 x1 xf32 >
157+ } : vector <1 x1 x1 xi1 >
158+ return
159+ }
160+ // CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d_masked
161+ // CHECK-NOT: vector.shape_cast
162+ // CHECK-NOT: memref.subview
163+
85164func.func @transfer_read_dynamic_rank_reducing (
86165 %arg : memref <?x1 xi8 , strided <[?, ?], offset : ?>>) -> vector <[16 ]x1 xi8 > {
87166 %c0 = arith.constant 0 : index
0 commit comments