-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][vector] Restrict vector.insert/vector.extract to disallow 0-d vectors #121458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Restrict vector.insert/vector.extract to disallow 0-d vectors #121458
Conversation
|
While the discussion is ongoing, I am posting this as draft. Please comment either here or on Discourse. |
0962edb to
473bd0f
Compare
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
473bd0f to
9434337
Compare
9434337 to
eb62bf0
Compare
eb62bf0 to
4084038
Compare
c214bbf to
29e7bb6
Compare
29e7bb6 to
9f9491a
Compare
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesThis patch restricts the use of vector.insert and vector.extract Ops in
The following are now illegal. Note that the source and result types %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
%1 = vector.extract %arg0[0, 0] : vector<f32> from vector<2x2xf32>Instead, use scalars as the source and result types: %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
%1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>Put differently, this PR removes the ambiguity when it comes to For more context, see the related RFC: Full diff: https://github.com/llvm/llvm-project/pull/121458.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..08f398a1c8ba6 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1294,6 +1294,10 @@ struct UnrollTransferReadConversion
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
/// accesses, and broadcasts and transposes in permutation maps.
+ ///
+ /// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
+ /// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
+ /// MemRef and Tensor source, respectively).
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
if (xferOp.getVectorType().getRank() <= options.targetRank)
@@ -1324,6 +1328,8 @@ struct UnrollTransferReadConversion
for (int64_t i = 0; i < dimSize; ++i) {
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ // FIXME: Rename this lambda - it does much more than just
+ // in-bounds-check generation.
vec = generateInBoundsCheck(
rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
/*inBoundsCase=*/
@@ -1338,12 +1344,33 @@ struct UnrollTransferReadConversion
insertionIndices.push_back(rewriter.getIndexAttr(i));
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
- auto newXferOp = b.create<vector::TransferReadOp>(
- loc, newXferVecType, xferOp.getBase(), xferIndices,
- AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
- xferOp.getPadding(), Value(), inBoundsAttr);
- maybeAssignMask(b, xferOp, newXferOp, i);
- return b.create<vector::InsertOp>(loc, newXferOp, vec,
+
+ // A value that's read after rank-reducing the original
+ // vector.transfer_read Op.
+ Value unpackedReadRes;
+ if (newXferVecType.getRank() != 0) {
+ // Unpacking Vector that's rank > 2
+ // (use vector.transfer_read to load a rank-reduced vector)
+ unpackedReadRes = b.create<vector::TransferReadOp>(
+ loc, newXferVecType, xferOp.getBase(), xferIndices,
+ AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
+ xferOp.getPadding(), Value(), inBoundsAttr);
+ maybeAssignMask(b, xferOp,
+ dyn_cast<vector::TransferReadOp>(
+ unpackedReadRes.getDefiningOp()),
+ i);
+ } else {
+ // Unpacking Vector that's rank == 1
+ // (use memref.load/tensor.extract to load a scalar)
+ unpackedReadRes = dyn_cast<MemRefType>(xferOp.getBase().getType())
+ ? b.create<memref::LoadOp>(
+ loc, xferOp.getBase(), xferIndices)
+ .getResult()
+ : b.create<tensor::ExtractOp>(
+ loc, xferOp.getBase(), xferIndices)
+ .getResult();
+ }
+ return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
insertionIndices);
},
/*outOfBoundsCase=*/
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a2357319bd23..dc4bcd9b6bd84 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1383,6 +1383,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult vector::ExtractOp::verify() {
+ if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
+ if (resTy.getRank() == 0)
+ return emitError(
+ "expected a scalar instead of a 0-d vector as the result type");
+
// Note: This check must come before getMixedPosition() to prevent a crash.
auto dynamicMarkersCount =
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -3122,6 +3127,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult InsertOp::verify() {
+ if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
+ if (srcTy.getRank() == 0)
+ return emitError(
+ "expected a scalar instead of a 0-d vector as the source operand");
+
SmallVector<OpFoldResult> position = getMixedPosition();
auto destVectorType = getDestVectorType();
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 04810ed52584f..57ec12a8ccac1 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -260,8 +260,8 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
// -----
func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
- // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
- %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
+ // expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
+ %1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f3220aed4360c..7d43f2a84dc77 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -298,12 +298,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
}
// CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
+func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
%1 = vector.insert %a, %b[] : f32 into vector<f32>
- // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
- %2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
- return %1, %2 : vector<f32>, vector<2x3xf32>
+ return %1 : vector<f32>
}
// CHECK-LABEL: @insert_poison_idx
|
|
Based on the discussion in the ODM, marking this as ready to review: |
…tract This patch enforces a restriction in the Vector dialect: the non-indexed operands of `vector.insert` and `vector.extract` must no longer be 0-D vectors. In other words, rank-0 vector types like `vector<f32>` are disallowed as the source or result. EXAMPLES -------- The following are now **illegal** (note the use of `vector<f32>`): ```mlir %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32> %1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32> ``` Instead, use scalars as the source and result types: ```mlir %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32> %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32> ``` This change serves three goals: 1. REDUCED AMBIGUITY -------------------- By enforcing scalar-only semantics when n-k = 0, we eliminate ambiguity in interpretation. Prior to this patch, both `f32` and `vector<f32>` were accepted in practice, though only scalars were intended. 2. MATCH IMPLEMENTATION TO DOCUMENTATION ---------------------------------------- The current behavior contradicts the documented intent. For example, vector.extract states: > Degenerates to an element type if n-k is zero. This patch enforces that intent in code. 3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT --------------------------------------------- With the stricter semantics in place, it’s natural and consistent to make `vector.insert` behave symmetrically to `vector.extract`, i.e., degenerate the source type to a scalar when n = 0. NOTES FOR REVIEWERS ------------------- 1. Main change is in "VectorOps.cpp", where stricter type checks are implemented. 2. Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups to remove now-illegal examples. 2. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now avoid using `vector.transfer_read` for scalar loads and instead rely on `memref.load` / `tensor.extract`. RELATED RFC ----------- * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
9f9491a to
efc29a7
Compare
|
@dcaballe , @Groverkss , this is the patch that we discussed in the ODM last week. Would you be able to take a look some time soon? It would be great to make some progress on this. |
| Inserts an n-D source vector (the value to store) into an (n + k)-D | ||
| destination vector at a specified k-D position. When n = 0, the source | ||
| degenerates to a scalar element inserted into the (0 + k)-D destination | ||
| vector. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a bit confusing when i read it together with vector.extract docs.
Can we do
n-D vector base vector (source for vector.extract, dest for vector.insert)
k-D position
(n-k)-D subvector, degenerates to scalar if k = n
it's a bit easier to follow then
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. But let me use the naming scheme from #131602, so:
valueToStore+destforvector.insert,sourceforvector.extract.
Let me know what you think!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this was addressed, what i meant was to use same rank for same class of operands:
n-D vector --> source/dest
k-D position
(n-k)-D subvector (valueToStore, result vector), degenerates to a scalar if k = n.
I don't mind the naming scheme, but having consistent rank documentation is easier to read.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for clarifying, now I see what you meant. Could you check the latest revision?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, that was a Git failure on my part 🤦🏻
Could you check this commit that I've just pushed?
|
|
||
| // A value that's read after rank-reducing the original | ||
| // vector.transfer_read Op. | ||
| Value unpackedReadRes; | ||
| if (newXferVecType.getRank() != 0) { | ||
| // Unpacking Vector that's rank > 2 | ||
| // (use vector.transfer_read to load a rank-reduced vector) | ||
| unpackedReadRes = b.create<vector::TransferReadOp>( | ||
| loc, newXferVecType, xferOp.getBase(), xferIndices, | ||
| AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), | ||
| xferOp.getPadding(), Value(), inBoundsAttr); | ||
| maybeAssignMask(b, xferOp, | ||
| dyn_cast<vector::TransferReadOp>( | ||
| unpackedReadRes.getDefiningOp()), | ||
| i); | ||
| } else { | ||
| // Unpacking Vector that's rank == 1 | ||
| // (use memref.load/tensor.extract to load a scalar) | ||
| unpackedReadRes = dyn_cast<MemRefType>(xferOp.getBase().getType()) | ||
| ? b.create<memref::LoadOp>( | ||
| loc, xferOp.getBase(), xferIndices) | ||
| .getResult() | ||
| : b.create<tensor::ExtractOp>( | ||
| loc, xferOp.getBase(), xferIndices) | ||
| .getResult(); | ||
| } | ||
| return b.create<vector::InsertOp>(loc, unpackedReadRes, vec, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is unrelated to the patch and changing behavior of other transformations. For now, if the transfer_read returns a 0-D vector, we should extract a scalar and then insert it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a great point, let me update this, thanks!
…ctor.extract Address Kunwar's comments
…sert/vector.extract Apply clang-format
|
@joker-eph , thanks for updating the title! I just wanted to point out, only non-indexed arguments are disallowed to be rank-0. This change will still allow the indexed arguments to be rank-0. This is explained in the summary. This has been a very long discussion, hence posting to avoid any potential confusion. |
|
Yes that seemed clear from the description. The title should be as descriptive as possible while remaining short, up-to-you if you want to add more information there. |
dcaballe
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for moving this forward! Please, wait for one more approval before landing.
…ctor.insert/vector.extract Update the docs as suggested by Kunwar
Groverkss
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! LGTM
…vectors (llvm#121458) This patch enforces a restriction in the Vector dialect: the non-indexed operands of `vector.insert` and `vector.extract` must no longer be 0-D vectors. In other words, rank-0 vector types like `vector<f32>` are disallowed as the source or result. EXAMPLES -------- The following are now **illegal** (note the use of `vector<f32>`): ```mlir %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32> %1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32> ``` Instead, use scalars as the source and result types: ```mlir %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32> %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32> ``` Note, this change serves three goals. These are summarised below. ## 1. REDUCED AMBIGUITY By enforcing scalar-only semantics when the result (`vector.extract`) or source (`vector.insert`) are rank-0, we eliminate ambiguity in interpretation. Prior to this patch, both `f32` and `vector<f32>` were accepted. ## 2. MATCH IMPLEMENTATION TO DOCUMENTATION The current behaviour contradicts the documented intent. For example, `vector.extract` states: > Degenerates to an element type if n-k is zero. This patch enforces that intent in code. ## 3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT With the stricter semantics in place, it’s natural and consistent to make `vector.insert` behave symmetrically to `vector.extract`, i.e., degenerate the source type to a scalar when n = 0. NOTES FOR REVIEWERS ------------------- 1. Main change is in "VectorOps.cpp", where stricter type checks are implemented. 2. Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups to remove now-illegal examples. 2. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now require an additional `vector.extract` when a preceding `vector.transfer_read` generates a rank-0 vector. RELATED RFC ----------- * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect

This patch enforces a restriction in the Vector dialect: the non-indexed
operands of
vector.insertandvector.extractmust no longer be 0-Dvectors. In other words, rank-0 vector types like
vector<f32>aredisallowed as the source or result.
EXAMPLES
The following are now illegal (note the use of
vector<f32>):Instead, use scalars as the source and result types:
Note, this change serves three goals. These are summarised below.
1. REDUCED AMBIGUITY
By enforcing scalar-only semantics when the result (
vector.extract)or source (
vector.insert) are rank-0, we eliminate ambiguityin interpretation. Prior to this patch, both
f32andvector<f32>were accepted.
2. MATCH IMPLEMENTATION TO DOCUMENTATION
The current behaviour contradicts the documented intent. For example,
vector.extractstates:This patch enforces that intent in code.
3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT
With the stricter semantics in place, it’s natural and consistent to
make
vector.insertbehave symmetrically tovector.extract, i.e.,degenerate the source type to a scalar when n = 0.
NOTES FOR REVIEWERS
implemented.
remove now-illegal examples.
require an additional
vector.extractwhen a precedingvector.transfer_readgenerates a rank-0 vector.RELATED RFC