@@ -136,6 +136,34 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
136
136
return %0 : vector <2 xf32 >
137
137
}
138
138
139
// Test case: vector.gather of a single element from a memref with a non-unit
// stride (memref<4xf32, strided<[2]>>). The directives below expect the gather
// to be replaced by an scf.if on the extracted mask bit: when the bit is set, a
// one-element vector.load at the extracted index is performed and its element
// is inserted into the pass-through vector (%arg3); otherwise the pass-through
// vector is yielded unchanged.
+ // CHECK-LABEL: @gather_memref_non_unit_stride_read_1_element
140
+ // CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
141
+ // CHECK: %[[IDX:.*]] = vector.extract %arg1[0] : index from vector<1xindex>
142
+ // CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
143
+ // CHECK: %[[VEC:.*]] = vector.load %arg0[%[[IDX]]] : memref<4xf32, strided<[2]>>, vector<1xf32>
144
+ // CHECK: %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
145
+ // CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
146
+ // CHECK: scf.yield %[[RES]] : vector<1xf32>
147
+ // CHECK: } else {
148
+ // CHECK: scf.yield %arg3 : vector<1xf32>
149
+ // CHECK: }
150
+ // CHECK: return %[[RET]] : vector<1xf32>
151
+ func.func @gather_memref_non_unit_stride_read_1_element (%base: memref <4 xf32 , strided <[2 ]>>, %v: vector <1 xindex >, %mask: vector <1 xi1 >, %pass_thru: vector <1 xf32 >) -> vector <1 xf32 > {
152
+ %c0 = arith.constant 0 : index
153
+ %0 = vector.gather %base [%c0 ][%v ], %mask , %pass_thru : memref <4 xf32 , strided <[2 ]>>, vector <1 xindex >, vector <1 xi1 >, vector <1 xf32 > into vector <1 xf32 >
154
+ return %0 : vector <1 xf32 >
155
+ }
156
+
157
// Test case: the same non-unit-stride memref (memref<4xf32, strided<[2]>>), but
// gathering two elements. The directives below expect the vector.gather to
// remain in the output with its original operands and types, i.e. it is not
// rewritten into a load.
+ // CHECK-LABEL: @gather_memref_non_unit_stride_read_more_than_1_element
158
+ // CHECK: %[[CONST:.*]] = arith.constant 0 : index
159
+ // CHECK: %[[RET:.*]] = vector.gather %arg0[%[[CONST]]] [%arg1], %arg2, %arg3 : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
160
+ // CHECK: return %[[RET]] : vector<2xf32>
161
+ func.func @gather_memref_non_unit_stride_read_more_than_1_element (%base: memref <4 xf32 , strided <[2 ]>>, %v: vector <2 xindex >, %mask: vector <2 xi1 >, %pass_thru: vector <2 xf32 >) -> vector <2 xf32 > {
162
+ %c0 = arith.constant 0 : index
163
+ %0 = vector.gather %base [%c0 ][%v ], %mask , %pass_thru : memref <4 xf32 , strided <[2 ]>>, vector <2 xindex >, vector <2 xi1 >, vector <2 xf32 > into vector <2 xf32 >
164
+ return %0 : vector <2 xf32 >
165
+ }
166
+
139
167
// CHECK-LABEL: @gather_tensor_2d
140
168
// CHECK: scf.if
141
169
// CHECK: tensor.extract
0 commit comments