@@ -136,11 +136,11 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
136136 return %0 : vector <2 xf32 >
137137}
138138
139- // CHECK-LABEL: @gather_strided_memref_1d
139+ // CHECK-LABEL: @gather_memref_non_unit_stride_read_1_element
140140// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
141- // CHECK: %1 = vector.extract %arg1[0] : index from vector<1xindex>
141+ // CHECK: %[[IDX:.*]] = vector.extract %arg1[0] : index from vector<1xindex>
142142// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
143- // CHECK: %[[VEC:.*]] = vector.load %arg0[%1 ] : memref<4xf32, strided<[2]>>, vector<1xf32>
143+ // CHECK: %[[VEC:.*]] = vector.load %arg0[%[[IDX]] ] : memref<4xf32, strided<[2]>>, vector<1xf32>
144144// CHECK: %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
145145// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
146146// CHECK: scf.yield %[[RES]] : vector<1xf32>
@@ -154,6 +154,16 @@ func.func @gather_memref_non_unit_stride_read_1_element(%base: memref<4xf32, str
154154 return %0 : vector <1 xf32 >
155155}
156156
157+ // CHECK-LABEL: @gather_memref_non_unit_stride_read_more_than_1_element
158+ // CHECK: %[[CONST:.*]] = arith.constant 0 : index
159+ // CHECK: %[[RET:.*]] = vector.gather %arg0[%[[CONST]]] [%arg1], %arg2, %arg3 : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
160+ // CHECK: return %[[RET]] : vector<2xf32>
161+ func.func @gather_memref_non_unit_stride_read_more_than_1_element (%base: memref <4 xf32 , strided <[2 ]>>, %v: vector <2 xindex >, %mask: vector <2 xi1 >, %pass_thru: vector <2 xf32 >) -> vector <2 xf32 > {
162+ %c0 = arith.constant 0 : index
163+ %0 = vector.gather %base [%c0 ][%v ], %mask , %pass_thru : memref <4 xf32 , strided <[2 ]>>, vector <2 xindex >, vector <2 xi1 >, vector <2 xf32 > into vector <2 xf32 >
164+ return %0 : vector <2 xf32 >
165+ }
166+
157167// CHECK-LABEL: @gather_tensor_2d
158168// CHECK: scf.if
159169// CHECK: tensor.extract
0 commit comments