-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Vector] Allow any strided memref for one-element vector.load in lowering vector.gather #122437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
…wering vector.gather Signed-off-by: PragmaTwice <[email protected]>
e8f3a28 to
c9853d6
Compare
|
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Twice (PragmaTwice) ChangesIn llvm-project/mlir/lib/Dialect/Vector/IR/VectorOps.cpp Lines 4971 to 4975 in 4e32271
.. if the output vector type of So here we can allow more cases in lowering Full diff: https://github.com/llvm/llvm-project/pull/122437.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index f1a5aa7664d2f3..4aff565b81b453 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -205,11 +205,13 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value condMask = op.getMask();
Value base = op.getBase();
- // vector.load requires the most minor memref dim to have unit stride
+ // vector.load requires the most minor memref dim to have unit stride,
+ // or the result vector type to have only one element
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
if (auto stridesAttr =
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
- if (stridesAttr.getStrides().back() != 1)
+ if (stridesAttr.getStrides().back() != 1 &&
+ resultTy.getNumElements() != 1)
return failure();
}
}
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5ad3a23e0ba15c..5d7aff6f8762ad 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -136,6 +136,24 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
return %0 : vector<2xf32>
}
+// CHECK-LABEL: @gather_strided_memref_1d
+// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
+// CHECK: %1 = vector.extract %arg1[0] : index from vector<1xindex>
+// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
+// CHECK: %[[VEC:.*]] = vector.load %arg0[%1] : memref<4xf32, strided<[2]>>, vector<1xf32>
+// CHECK: %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
+// CHECK: scf.yield %[[RES]] : vector<1xf32>
+// CHECK: } else {
+// CHECK: scf.yield %arg3 : vector<1xf32>
+// CHECK: }
+// CHECK: return %[[RET]] : vector<1xf32>
+func.func @gather_strided_memref_1d(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
+ return %0 : vector<1xf32>
+}
+
// CHECK-LABEL: @gather_tensor_2d
// CHECK: scf.if
// CHECK: tensor.extract
|
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, thanks for the fix!
I've left a couple of minor comments. Also, could you add a negative test when reading more than one element from a strided memref? Thanks!
Co-authored-by: Andrzej Warzyński <[email protected]>
|
Thank you for your review! I've applied these great suggestions and added a negative case : ) |
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Do you have commit access? If not, I can land this for you.
Thank you! I don't have access yet, so I'd appreciate it if you could land it : ) |
In
Gather1DToConditionalLoads, currently we will check if the stride of the most minor dim of the input memref is 1. And if not, the rewriting pattern will not be applied. However, according to the verification ofvector.loadhere:llvm-project/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Lines 4971 to 4975 in 4e32271
.. if the output vector type of
vector.loadcontains only one element, we can ignore the requirement of the stride of the input memref, i.e. the input memref can be with any stride layout attribute in such case.So here we can allow more cases in lowering
vector.gatherby relaxing such check.As shown in the test case attached in this patch here, now
vector.gatherof memref with non-trivial stride can be lowered successfully if the result vector contains only one element.