Skip to content

Commit b91d5af

Browse files
[MLIR][Vector] Allow any strided memref for one-element vector.load in lowering vector.gather (#122437)
`Gather1DToConditionalLoads` currently checks whether the stride of the most minor dim of the input memref is 1; if it is not, the rewrite pattern is not applied. However, according to the verifier of `vector.load` (see https://github.com/llvm/llvm-project/blob/4e32271e8b304eb018c69f74c16edd1668fcdaf3/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L4971-L4975), if the output vector type of `vector.load` contains only one element, the stride requirement on the input memref can be ignored, i.e. the input memref may have any strided layout attribute in that case. We can therefore allow more cases when lowering `vector.gather` by relaxing this check. As shown in the test case added in this patch [here](https://github.com/llvm/llvm-project/blob/1933fbad58302814ccce5991a9320c0967f3571b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir#L151), a `vector.gather` on a memref with a non-trivial stride can now be lowered successfully if the result vector contains only one element. --------- Signed-off-by: PragmaTwice <[email protected]> Co-authored-by: Andrzej Warzyński <[email protected]>
1 parent be6c752 commit b91d5af

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,12 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
206206
Value base = op.getBase();
207207

208208
// vector.load requires the most minor memref dim to have unit stride
209+
// (unless reading exactly 1 element)
209210
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
210211
if (auto stridesAttr =
211212
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
212-
if (stridesAttr.getStrides().back() != 1)
213+
if (stridesAttr.getStrides().back() != 1 &&
214+
resultTy.getNumElements() != 1)
213215
return failure();
214216
}
215217
}

mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,34 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
136136
return %0 : vector<2xf32>
137137
}
138138

139+
// CHECK-LABEL: @gather_memref_non_unit_stride_read_1_element
140+
// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
141+
// CHECK: %[[IDX:.*]] = vector.extract %arg1[0] : index from vector<1xindex>
142+
// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
143+
// CHECK: %[[VEC:.*]] = vector.load %arg0[%[[IDX]]] : memref<4xf32, strided<[2]>>, vector<1xf32>
144+
// CHECK: %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
145+
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
146+
// CHECK: scf.yield %[[RES]] : vector<1xf32>
147+
// CHECK: } else {
148+
// CHECK: scf.yield %arg3 : vector<1xf32>
149+
// CHECK: }
150+
// CHECK: return %[[RET]] : vector<1xf32>
151+
func.func @gather_memref_non_unit_stride_read_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
152+
%c0 = arith.constant 0 : index
153+
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
154+
return %0 : vector<1xf32>
155+
}
156+
157+
// CHECK-LABEL: @gather_memref_non_unit_stride_read_more_than_1_element
158+
// CHECK: %[[CONST:.*]] = arith.constant 0 : index
159+
// CHECK: %[[RET:.*]] = vector.gather %arg0[%[[CONST]]] [%arg1], %arg2, %arg3 : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
160+
// CHECK: return %[[RET]] : vector<2xf32>
161+
func.func @gather_memref_non_unit_stride_read_more_than_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
162+
%c0 = arith.constant 0 : index
163+
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
164+
return %0 : vector<2xf32>
165+
}
166+
139167
// CHECK-LABEL: @gather_tensor_2d
140168
// CHECK: scf.if
141169
// CHECK: tensor.extract

0 commit comments

Comments
 (0)