
@PragmaTwice PragmaTwice commented Jan 10, 2025

Gather1DToConditionalLoads currently checks whether the stride of the most minor dim of the input memref is 1, and the rewrite pattern is not applied if it is not. However, according to the verifier of vector.load here:

// If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
// need any strides limitations.
if (!vecTy.isScalable() &&
    (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
  return success();

... if the result vector type of vector.load contains only one element, the stride requirement on the input memref can be ignored, i.e. the input memref may have any strided layout attribute in that case.
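To illustrate why a one-element load is insensitive to the stride, here is a minimal standalone sketch. It does not use MLIR; it assumes a 1-D strided buffer in which logical element i lives at linear offset `offset + i * stride`, and it models a vector load as a read of consecutive addresses. The helper names `logicalOffsets` and `contiguousOffsets` are hypothetical and exist only for this example.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Linear offsets of the n logical elements starting at logical index `start`
// in a 1-D strided buffer (assumed layout: element i at offset + i * stride).
std::vector<int64_t> logicalOffsets(int64_t offset, int64_t stride,
                                    int64_t start, int64_t n) {
  assert(n >= 0 && "element count must be non-negative");
  std::vector<int64_t> result;
  for (int64_t i = 0; i < n; ++i)
    result.push_back(offset + (start + i) * stride);
  return result;
}

// Linear offsets touched by a read of n consecutive addresses that begins at
// the address of logical element `start` (our model of a vector load).
std::vector<int64_t> contiguousOffsets(int64_t offset, int64_t stride,
                                       int64_t start, int64_t n) {
  assert(n >= 0 && "element count must be non-negative");
  std::vector<int64_t> result;
  const int64_t first = offset + start * stride;
  for (int64_t i = 0; i < n; ++i)
    result.push_back(first + i);
  return result;
}
```

Under this model the two functions agree whenever the stride is 1 or n is 1, and they diverge for a stride of 2 with n of 2, which matches the intuition that a single-element load behaves like a scalar load.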

So we can allow more cases when lowering vector.gather by relaxing this check.

As shown in the test case attached to this patch, vector.gather on a memref with a non-unit stride can now be lowered successfully if the result vector contains only one element.
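The relaxed guard can be summarized as a small standalone predicate. This is only a sketch of the condition in the diff, not the actual pattern code: `canLowerToConditionalLoads` is a hypothetical name, and `lastStride` is modelled as `std::nullopt` when the memref carries no strided layout attribute (the case in which the pattern performs no stride check at all).

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for the stride guard in Gather1DToConditionalLoads.
// lastStride:        stride of the most minor memref dim, or std::nullopt if
//                    the memref has no strided layout attribute.
// numResultElements: number of elements in the gather's result vector.
bool canLowerToConditionalLoads(std::optional<int64_t> lastStride,
                                int64_t numResultElements) {
  assert(numResultElements >= 1 && "vector types have at least one element");
  // No strided layout attribute: the guard is not consulted.
  if (!lastStride)
    return true;
  // The patch bails out only when the stride is non-unit AND the result has
  // more than one element; otherwise the rewrite may proceed.
  return *lastStride == 1 || numResultElements == 1;
}
```

For example, a `strided<[2]>` memref with a `vector<1xf32>` result passes the guard, while the same memref with a two-element result is still rejected, which is the shape of the negative test requested in review.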


github-actions bot commented Jan 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.


llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/122437.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2)
  • (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+18)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index f1a5aa7664d2f3..4aff565b81b453 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -205,11 +205,13 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // vector.load requires the most minor memref dim to have unit stride
+    // vector.load requires the most minor memref dim to have unit stride,
+    // or the result vector type to have only one element
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
-        if (stridesAttr.getStrides().back() != 1)
+        if (stridesAttr.getStrides().back() != 1 &&
+            resultTy.getNumElements() != 1)
           return failure();
       }
     }
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5ad3a23e0ba15c..5d7aff6f8762ad 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -136,6 +136,24 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
   return %0 : vector<2xf32>
 }
 
+// CHECK-LABEL: @gather_strided_memref_1d
+// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
+// CHECK: %1 = vector.extract %arg1[0] : index from vector<1xindex>
+// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
+// CHECK:   %[[VEC:.*]] = vector.load %arg0[%1] : memref<4xf32, strided<[2]>>, vector<1xf32>
+// CHECK:   %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
+// CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
+// CHECK:   scf.yield %[[RES]] : vector<1xf32>
+// CHECK: } else {
+// CHECK:    scf.yield %arg3 : vector<1xf32>
+// CHECK: }
+// CHECK: return %[[RET]] : vector<1xf32>
+func.func @gather_strided_memref_1d(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
+  return %0 : vector<1xf32>
+}
+
 // CHECK-LABEL: @gather_tensor_2d
 // CHECK:  scf.if
 // CHECK:    tensor.extract


@banach-space banach-space left a comment


Makes sense, thanks for the fix!

I've left a couple of minor comments. Also, could you add a negative test when reading more than one element from a strided memref? Thanks!

@PragmaTwice

Thank you for your review! I've applied these great suggestions and added a negative case : )


@banach-space banach-space left a comment


LGTM, thanks!

Do you have commit access? If not, I can land this for you.


PragmaTwice commented Jan 12, 2025

> LGTM, thanks!
>
> Do you have commit access? If not, I can land this for you.

Thank you! I don't have access yet, so I'd appreciate it if you could land it : )

@banach-space banach-space merged commit b91d5af into llvm:main Jan 12, 2025
8 checks passed