[mlir][Linalg] Remove implicit zero rank vectors in vectorization #116069
base: main
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

Vectorization today converts any zero-rank vector it encounters into a scalar. This patch moves that check out of the generic path and applies it only to operations that do not support zero-rank vectors yet. For linalg vectorization, this is primarily vector::MultiDimReductionOp and vector::ContractionOp.

Full diff: https://github.com/llvm/llvm-project/pull/116069.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 23b46a2ee55f8d..9f35e40a964af6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -590,9 +590,6 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) {
/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
auto dstVecType = dyn_cast<VectorType>(dstType);
- // If no shape to broadcast to, just return `value`.
- if (dstVecType.getRank() == 0)
- return value;
if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
vector::BroadcastableToResult::Success)
return value;
@@ -608,6 +605,15 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
Value valueToReduce, Value acc,
ArrayRef<bool> dimsToMask) {
+ // If `acc` is a zero-rank vector, extract the scalar value from it, since
+ // vector.multi_reduction does not support 0 rank vectors yet.
+ // TODO: Remove this once vector.multi_reduction supports 0 rank vectors.
+ auto accVecType = dyn_cast<VectorType>(acc.getType());
+ if (accVecType && accVecType.getRank() == 0) {
+ acc = b.create<vector::ExtractOp>(reduceOp->getLoc(), acc,
+ ArrayRef<int64_t>());
+ }
+
auto maybeKind = getCombinerOpKind(reduceOp);
assert(maybeKind && "Failed precondition: could not get reduction kind");
return b.create<vector::MultiDimReductionOp>(
@@ -1410,12 +1416,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
}
- // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
- // TODO: remove this.
- if (readType.getRank() == 0)
- readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
- ArrayRef<int64_t>());
-
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
<< "\n");
bvm.map(bbarg, readValue);
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 0c996bed996d3c..ee18610071eb20 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1777,11 +1777,8 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @zero_dim_tensor
// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-// CHECK: vector.extract
// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-// CHECK: vector.extract
-// CHECK: arith.addf {{.*}} : f32
-// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
+// CHECK: arith.addf {{.*}} : vector<f32>
// CHECK: vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
// -----
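To make the test change above concrete, here is a hypothetical 0-d example (illustrative only, not part of the patch): an elementwise add over `tensor<f32>` operands. After this patch, vectorization keeps the 0-d vectors end to end instead of extracting scalars and re-broadcasting.

```mlir
// Hypothetical input: an elementwise add over 0-d tensors.
func.func @zero_dim_add(%a: tensor<f32>, %b: tensor<f32>) -> tensor<f32> {
  %init = tensor.empty() : tensor<f32>
  %0 = linalg.generic
      {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>,
                        affine_map<() -> ()>],
       iterator_types = []}
      ins(%a, %b : tensor<f32>, tensor<f32>) outs(%init : tensor<f32>) {
  ^bb0(%x: f32, %y: f32, %out: f32):
    %s = arith.addf %x, %y : f32
    linalg.yield %s : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}

// After this patch, vectorization stays in 0-d vectors (sketch):
//   %lhs = vector.transfer_read ... : tensor<f32>, vector<f32>
//   %rhs = vector.transfer_read ... : tensor<f32>, vector<f32>
//   %sum = arith.addf %lhs, %rhs : vector<f32>
//   vector.transfer_write %sum, ... : vector<f32>, tensor<f32>
//
// Previously, each transfer_read was followed by a vector.extract down to a
// scalar f32, and the result was vector.broadcast back to vector<f32> before
// the transfer_write -- exactly the CHECK lines removed above.
```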
I'm surprised that there was only one test change. This likely means that linalg vectorization isn't well tested for 0-rank tensors. I'm going to try to add more tests as well, but would like to get others' opinions too.
Makes sense.
How did you identify these Ops? Mostly trying to make sure that we are not missing anything due to limited testing.
+1 It's a corner case that's been barely tested, so that will be much appreciated. I should warn you though - some refactoring/re-organization of those tests is long overdue 😅 LGTM, but please give it 24hrs before landing, in case others want to chime in.
I don't have a strong concern here, but I'm curious about the motivation to prioritize a 0-D vector over a scalar. Is this mostly to preserve a 0-D tensor input?
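For readers following the question above, the distinction being discussed can be sketched roughly as follows (illustrative IR only; exact op syntax may differ by MLIR version):

```mlir
// A 0-D vector read keeps vector semantics (e.g. it can be broadcast):
%v = vector.transfer_read %t[], %pad : tensor<f32>, vector<f32>
%b = vector.broadcast %v : vector<f32> to vector<4xf32>

// The old path instead unconditionally dropped to a scalar:
%s = vector.extract %v[] : f32 from vector<f32>
```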