Skip to content

Commit 6a76046

Browse files
committed
fix test case and add a negative case
1 parent 70f4b42 commit 6a76046

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,11 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
136136
return %0 : vector<2xf32>
137137
}
138138

139-
// CHECK-LABEL: @gather_strided_memref_1d
139+
// CHECK-LABEL: @gather_memref_non_unit_stride_read_1_element
140140
// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
141-
// CHECK: %1 = vector.extract %arg1[0] : index from vector<1xindex>
141+
// CHECK: %[[IDX:.*]] = vector.extract %arg1[0] : index from vector<1xindex>
142142
// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
143-
// CHECK: %[[VEC:.*]] = vector.load %arg0[%1] : memref<4xf32, strided<[2]>>, vector<1xf32>
143+
// CHECK: %[[VEC:.*]] = vector.load %arg0[%[[IDX]]] : memref<4xf32, strided<[2]>>, vector<1xf32>
144144
// CHECK: %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
145145
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
146146
// CHECK: scf.yield %[[RES]] : vector<1xf32>
@@ -154,6 +154,16 @@ func.func @gather_memref_non_unit_stride_read_1_element(%base: memref<4xf32, str
154154
return %0 : vector<1xf32>
155155
}
156156

157+
// CHECK-LABEL: @gather_memref_non_unit_stride_read_more_than_1_element
158+
// CHECK: %[[CONST:.*]] = arith.constant 0 : index
159+
// CHECK: %[[RET:.*]] = vector.gather %arg0[%[[CONST]]] [%arg1], %arg2, %arg3 : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
160+
// CHECK: return %[[RET]] : vector<2xf32>
161+
func.func @gather_memref_non_unit_stride_read_more_than_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
162+
%c0 = arith.constant 0 : index
163+
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
164+
return %0 : vector<2xf32>
165+
}
166+
157167
// CHECK-LABEL: @gather_tensor_2d
158168
// CHECK: scf.if
159169
// CHECK: tensor.extract

0 commit comments

Comments
 (0)