@@ -136,6 +136,34 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
136136  return  %0  : vector <2 xf32 >
137137}
138138
// Masked gather of a single element from a memref with a non-unit stride
// (strided<[2]>). The expected output branches on the extracted mask bit:
// when it is set, one element is read with a vector<1xf32> vector.load at
// the extracted index and inserted into the pass-through vector; otherwise
// the pass-through vector is yielded unchanged.
//
// Function arguments are captured with CHECK-SAME so the expectations do not
// depend on the printer's auto-generated SSA names.
// CHECK-LABEL: @gather_memref_non_unit_stride_read_1_element
// CHECK-SAME:    %[[BASE:[a-zA-Z0-9_]+]]: memref<4xf32, strided<[2]>>
// CHECK-SAME:    %[[IDXVEC:[a-zA-Z0-9_]+]]: vector<1xindex>
// CHECK-SAME:    %[[MASKVEC:[a-zA-Z0-9_]+]]: vector<1xi1>
// CHECK-SAME:    %[[PASS:[a-zA-Z0-9_]+]]: vector<1xf32>
// CHECK: %[[MASK:.+]] = vector.extract %[[MASKVEC]][0] : i1 from vector<1xi1>
// CHECK: %[[IDX:.+]] = vector.extract %[[IDXVEC]][0] : index from vector<1xindex>
// CHECK: %[[RET:.+]] = scf.if %[[MASK]] -> (vector<1xf32>) {
// CHECK:   %[[VEC:.+]] = vector.load %[[BASE]][%[[IDX]]] : memref<4xf32, strided<[2]>>, vector<1xf32>
// CHECK:   %[[VAL:.+]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
// CHECK:   %[[RES:.+]] = vector.insert %[[VAL]], %[[PASS]] [0] : f32 into vector<1xf32>
// CHECK:   scf.yield %[[RES]] : vector<1xf32>
// CHECK: } else {
// CHECK:   scf.yield %[[PASS]] : vector<1xf32>
// CHECK: }
// CHECK: return %[[RET]] : vector<1xf32>
func.func @gather_memref_non_unit_stride_read_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
  %c0 = arith.constant 0 : index
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
  return %0 : vector<1xf32>
}
156+ 
// Masked gather of two elements from a memref with a non-unit stride
// (strided<[2]>). The expected output still contains the vector.gather on
// the original operands; only the zero base index is matched before it.
// NOTE(review): presumably the gather is kept because a multi-element
// vector.load would read adjacent memory and ignore the stride - confirm
// against the lowering pattern.
//
// Function arguments are captured with CHECK-SAME so the expectations do not
// depend on the printer's auto-generated SSA names.
// CHECK-LABEL: @gather_memref_non_unit_stride_read_more_than_1_element
// CHECK-SAME:    %[[BASE:[a-zA-Z0-9_]+]]: memref<4xf32, strided<[2]>>
// CHECK-SAME:    %[[IDXVEC:[a-zA-Z0-9_]+]]: vector<2xindex>
// CHECK-SAME:    %[[MASKVEC:[a-zA-Z0-9_]+]]: vector<2xi1>
// CHECK-SAME:    %[[PASS:[a-zA-Z0-9_]+]]: vector<2xf32>
// CHECK: %[[CONST:.+]] = arith.constant 0 : index
// CHECK: %[[RET:.+]] = vector.gather %[[BASE]][%[[CONST]]] [%[[IDXVEC]]], %[[MASKVEC]], %[[PASS]] : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
// CHECK: return %[[RET]] : vector<2xf32>
func.func @gather_memref_non_unit_stride_read_more_than_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
  %c0 = arith.constant 0 : index
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}
166+ 
139167// CHECK-LABEL: @gather_tensor_2d 
140168// CHECK:  scf.if 
141169// CHECK:    tensor.extract 
0 commit comments