
Conversation

@Groverkss
Member

Vectorization today converts any zero-rank vector it encounters into a scalar. This patch narrows that conversion from all operations to only the operations that do not yet support zero-rank vectors. For linalg vectorization, these are primarily vector::MultiDimReductionOp and vector::ContractionOp.

@llvmbot
Member

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

Vectorization today converts any zero-rank vector it encounters into a scalar. This patch narrows that conversion from all operations to only the operations that do not yet support zero-rank vectors. For linalg vectorization, these are primarily vector::MultiDimReductionOp and vector::ContractionOp.


Full diff: https://github.com/llvm/llvm-project/pull/116069.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+9-9)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+1-4)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 23b46a2ee55f8d..9f35e40a964af6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -590,9 +590,6 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) {
 /// otherwise.
 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
   auto dstVecType = dyn_cast<VectorType>(dstType);
-  // If no shape to broadcast to, just return `value`.
-  if (dstVecType.getRank() == 0)
-    return value;
   if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
       vector::BroadcastableToResult::Success)
     return value;
@@ -608,6 +605,15 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                       Value valueToReduce, Value acc,
                                       ArrayRef<bool> dimsToMask) {
+  // If `acc` is a zero-rank vector, extract the scalar value from it, since
+  // vector.multi_reduction does not support 0 rank vectors yet.
+  // TODO: Remove this once vector.multi_reduction supports 0 rank vectors.
+  auto accVecType = dyn_cast<VectorType>(acc.getType());
+  if (accVecType && accVecType.getRank() == 0) {
+    acc = b.create<vector::ExtractOp>(reduceOp->getLoc(), acc,
+                                      ArrayRef<int64_t>());
+  }
+
   auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   return b.create<vector::MultiDimReductionOp>(
@@ -1410,12 +1416,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
     }
 
-    // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
-    // TODO: remove this.
-    if (readType.getRank() == 0)
-      readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
-                                                     ArrayRef<int64_t>());
-
     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
                                  << "\n");
     bvm.map(bbarg, readValue);
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 0c996bed996d3c..ee18610071eb20 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1777,11 +1777,8 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: func @zero_dim_tensor
 //       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-//       CHECK:     vector.extract
 //       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-//       CHECK:     vector.extract
-//       CHECK:     arith.addf {{.*}} : f32
-//       CHECK:     vector.broadcast %{{.*}} : f32 to vector<f32>
+//       CHECK:     arith.addf {{.*}} : vector<f32>
 //       CHECK:     vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
 
 // -----
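
For reference, the updated CHECK lines for `@zero_dim_tensor` correspond to IR of roughly the following shape. This is a hand-written sketch reconstructed from the CHECK patterns (value names and the padding constant are illustrative), not verbatim vectorizer output:

```mlir
// After this patch, a zero-dim elementwise addition stays in 0-D vector form
// end-to-end; there is no longer a vector.extract / vector.broadcast
// round-trip through scalars.
%pad = arith.constant 0.0 : f32
%lhs = vector.transfer_read %a[], %pad : tensor<f32>, vector<f32>
%rhs = vector.transfer_read %b[], %pad : tensor<f32>, vector<f32>
%sum = arith.addf %lhs, %rhs : vector<f32>
%out = vector.transfer_write %sum, %b[] : vector<f32>, tensor<f32>
```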

@Groverkss
Member Author

I'm surprised that there was only one test change. This likely means that linalg vectorization isn't well tested for zero-rank tensors. I'm going to try to add more tests, but would also like to get others' opinions.
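
One shape of additional coverage could be a reduction from a 1-D tensor into a zero-dim tensor, which would exercise the new scalar-accumulator extraction in buildMultiDimReduce. The function below is a hypothetical test input sketched for illustration, not taken from the PR:

```mlir
// Hypothetical coverage case: summing a 1-D tensor into a zero-dim
// accumulator. Vectorizing this should hit the 0-rank `acc` path in
// buildMultiDimReduce, which extracts the scalar before building
// vector.multi_reduction.
func.func @reduce_to_zero_dim(%in: tensor<4xf32>, %init: tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
      iterator_types = ["reduction"]}
      ins(%in : tensor<4xf32>) outs(%init : tensor<f32>) {
    ^bb0(%x: f32, %acc: f32):
      %s = arith.addf %x, %acc : f32
      linalg.yield %s : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}
```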

@banach-space
Contributor

Makes sense.

This patch moves this check from all operations, to only operations that do not support zero-rank operations yet.

How did you identify these Ops? Mostly trying to make sure that we are not missing anything due to limited testing.

This likely means that linalg vectorization isn't well tested for 0 rank tensors. I'm going to try to add more tests as well

+1 It's a corner case that's been barely tested, so that will be much appreciated. I should warn you though - some refactor/re-org of those tests is long-overdue 😅

LGTM, but please give it 24hrs before landing. In case others want to chime in.

@dcaballe
Contributor

I don't have a strong concern here but I'm curious about the motivation to prioritize a 0-D vector over a scalar. Is this mostly to preserve a 0-D tensor input?
