-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][Vector] Improve vector.mask verifier
#139823
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesThis PR improves the As part of this change, the logic that ensures that a terminator is present in the region mask has been simplified to make it less surprising to the user when a Full diff: https://github.com/llvm/llvm-project/pull/139823.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3aefcea8de994..2820759687293 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2482,8 +2482,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
masked. Values used within the region are captured from above. Only one
*maskable* operation can be masked with a `vector.mask` operation at a time.
An operation is *maskable* if it implements the `MaskableOpInterface`. The
- terminator yields all results of the maskable operation to the result of
- this operation.
+ terminator yields all results from the maskable operation to the result of
+ this operation. No other values are allowed to be yielded.
+
+ An empty `vector.mask` operation is considered ill-formed but legal to
+ facilitate optimizations across the `vector.mask` operation. It is considered
+ a no-op regardless of its returned values and will be removed by the
+ canonicalizer.
The vector mask argument holds a bit for each vector lane and determines
which vector lanes should execute the maskable operation and which ones
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..395680b5e814b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6543,29 +6543,31 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
}
void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) {
- OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
- MaskOp>::ensureTerminator(region, builder, loc);
- // Keep the default yield terminator if the number of masked operations is not
- // the expected. This case will trigger a verification failure.
- Block &block = region.front();
- if (block.getOperations().size() != 2)
+ // Create default terminator if there are no ops to mask.
+ if (region.empty() || region.front().empty()) {
+ OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+ MaskOp>::ensureTerminator(region, builder, loc);
return;
+ }
- // Replace default yield terminator with a new one that returns the results
- // from the masked operation.
- OpBuilder opBuilder(builder.getContext());
- Operation *maskedOp = &block.front();
- Operation *oldYieldOp = &block.back();
- assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
+ // If region has an explicit terminator, we don't modify it.
+ Block &block = region.front();
+ if (isa<vector::YieldOp>(block.back()))
+ return;
- // Empty vector.mask op.
- if (maskedOp == oldYieldOp)
+ // Create default terminator if the number of masked operations is not
+ // one. This case will trigger a verification failure.
+ if (block.getOperations().size() != 1) {
+ OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+ MaskOp>::ensureTerminator(region, builder, loc);
return;
+ }
- opBuilder.setInsertionPoint(oldYieldOp);
+ // Create a terminator that yields the results from the masked operation.
+ OpBuilder opBuilder(builder.getContext());
+ Operation *maskedOp = &block.front();
+ opBuilder.setInsertionPointToEnd(&block);
opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
- oldYieldOp->dropAllReferences();
- oldYieldOp->erase();
}
LogicalResult MaskOp::verify() {
@@ -6600,6 +6602,11 @@ LogicalResult MaskOp::verify() {
return emitOpError("expects number of results to match maskable operation "
"number of results");
+ if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
+ return emitOpError(
+ "expects all the results from the MaskableOpInterface to "
+ "be returned by the terminator");
+
if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
return emitOpError(
"expects result type to match maskable operation result type");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index be65d4c2eef58..de6333ca8dc7b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1747,6 +1747,34 @@ func.func @vector_mask_0d_mask(%arg0: tensor<2x4xi32>,
// -----
+func.func @vector_mask_non_empty_external_return(%t0: tensor<?xf32>, %idx: index,
+ %m0: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> {
+ %ft0 = arith.constant 0.0 : f32
+ // expected-error@+1 {{'vector.mask' op expects all the results from the MaskableOpInterface to be returned by the terminator}}
+ %0 = vector.mask %m0 {
+ %1 =vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
+ vector.yield %ext : vector<16xf32>
+ } : vector<16xi1> -> vector<16xf32>
+
+ return %0 : vector<16xf32>
+}
+
+// -----
+
+func.func @vector_mask_non_empty_mixed_return(%t0: tensor<?xf32>, %idx: index,
+ %m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
+ %ft0 = arith.constant 0.0 : f32
+ // expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}}
+ %0:2 = vector.mask %m0 {
+ %1 =vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
+ vector.yield %1, %ext : vector<16xf32>, vector<16xf32>
+ } : vector<16xi1> -> (vector<16xf32>, vector<16xf32>)
+
+ return %0#0, %0#1 : vector<16xf32>, vector<16xf32>
+}
+
+// -----
+
func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
// expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
%0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thank you!
I've left a few comments, but these are mostly questions so that I better understand the underlying design.
|
|
||
| // ----- | ||
|
|
||
| func.func @vector_mask_non_empty_external_return(%t0: tensor<?xf32>, %idx: index, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[ultranit] In the spirit of "less is more" 😅 (similar suggestion elsewhere)
| func.func @vector_mask_non_empty_external_return(%t0: tensor<?xf32>, %idx: index, | |
| func.func @vector_mask_non_empty_external_return(%t: tensor<?xf32>, %idx: index, |
| terminator yields all results from the maskable operation to the result of | ||
| this operation. No other values are allowed to be yielded. | ||
|
|
||
| An empty `vector.mask` operation is considered ill-formed but legal to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't want come across as nit-picking, but what does "ill-formed" mean? To me, the sentence that follows suggests that empty masks are totally fine, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better now?
| "expects all the results from the MaskableOpInterface to " | ||
| "be returned by the terminator"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will trigger even when "all the results from the MaskableOpInterface" are returned by the terminator.
| "expects all the results from the MaskableOpInterface to " | |
| "be returned by the terminator"); | |
| "expects all the results from the MaskableOpInterface to " | |
| "match all the values returned by the terminator"); |
| // Create default terminator if there are no ops to mask. | ||
| if (region.empty() || region.front().empty()) { | ||
| OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl< | ||
| MaskOp>::ensureTerminator(region, builder, loc); | ||
| return; | ||
| } | ||
|
|
||
| // Replace default yield terminator with a new one that returns the results | ||
| // from the masked operation. | ||
| OpBuilder opBuilder(builder.getContext()); | ||
| Operation *maskedOp = &block.front(); | ||
| Operation *oldYieldOp = &block.back(); | ||
| assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp"); | ||
| // If region has an explicit terminator, we don't modify it. | ||
| Block &block = region.front(); | ||
| if (isa<vector::YieldOp>(block.back())) | ||
| return; | ||
|
|
||
| // Empty vector.mask op. | ||
| if (maskedOp == oldYieldOp) | ||
| // Create default terminator if the number of masked operations is not | ||
| // one. This case will trigger a verification failure. | ||
| if (block.getOperations().size() != 1) { | ||
| OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl< | ||
| MaskOp>::ensureTerminator(region, builder, loc); | ||
| return; | ||
| } | ||
|
|
||
| opBuilder.setInsertionPoint(oldYieldOp); | ||
| // Create a terminator that yields the results from the masked operation. | ||
| OpBuilder opBuilder(builder.getContext()); | ||
| Operation *maskedOp = &block.front(); | ||
| opBuilder.setInsertionPointToEnd(&block); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not immediately clear what the cases are. IIUC, it's something like this:
// 1. For an empty `vector.mask`, create a default terminator.
if (region.empty() || region.front().empty()) {
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
MaskOp>::ensureTerminator(region, builder, loc);
return;
}
// 2. For a non-empty `vector.mask` _with_ an existing terminator, do nothing.
Block &block = region.front();
if (isa<vector::YieldOp>(block.back()))
return;
// 3. For a non-empty `vector.mask` _without_ a terminator, split into two cases.
// 3.1. If the number of masked operations is != 1, create the default terminator (this case is invalid and will be flagged by the Op verifier).
if (block.getOperations().size() != 1) {
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
MaskOp>::ensureTerminator(region, builder, loc);
return;
}
// 3.2 Otherwise, create a terminator that yields all the results from the masked operation.
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
opBuilder.setInsertionPointToEnd(&block);
opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());? Did I get it correctly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I highlighted the three use cases
| func.func @vector_mask_non_empty_mixed_return(%t0: tensor<?xf32>, %idx: index, | ||
| %m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) { | ||
| %ft0 = arith.constant 0.0 : f32 | ||
| // expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This fails today with this error:
error: 'vector.mask' op expects number of results to match mask region yielded values
Its not clear to me what generates this new error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original code in ensureTerminator was replacing the explicitly-provided terminator with a different one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the feedback!
| terminator yields all results from the maskable operation to the result of | ||
| this operation. No other values are allowed to be yielded. | ||
|
|
||
| An empty `vector.mask` operation is considered ill-formed but legal to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better now?
| // Create default terminator if there are no ops to mask. | ||
| if (region.empty() || region.front().empty()) { | ||
| OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl< | ||
| MaskOp>::ensureTerminator(region, builder, loc); | ||
| return; | ||
| } | ||
|
|
||
| // Replace default yield terminator with a new one that returns the results | ||
| // from the masked operation. | ||
| OpBuilder opBuilder(builder.getContext()); | ||
| Operation *maskedOp = &block.front(); | ||
| Operation *oldYieldOp = &block.back(); | ||
| assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp"); | ||
| // If region has an explicit terminator, we don't modify it. | ||
| Block &block = region.front(); | ||
| if (isa<vector::YieldOp>(block.back())) | ||
| return; | ||
|
|
||
| // Empty vector.mask op. | ||
| if (maskedOp == oldYieldOp) | ||
| // Create default terminator if the number of masked operations is not | ||
| // one. This case will trigger a verification failure. | ||
| if (block.getOperations().size() != 1) { | ||
| OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl< | ||
| MaskOp>::ensureTerminator(region, builder, loc); | ||
| return; | ||
| } | ||
|
|
||
| opBuilder.setInsertionPoint(oldYieldOp); | ||
| // Create a terminator that yields the results from the masked operation. | ||
| OpBuilder opBuilder(builder.getContext()); | ||
| Operation *maskedOp = &block.front(); | ||
| opBuilder.setInsertionPointToEnd(&block); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I highlighted the three use cases
| func.func @vector_mask_non_empty_mixed_return(%t0: tensor<?xf32>, %idx: index, | ||
| %m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) { | ||
| %ft0 = arith.constant 0.0 : f32 | ||
| // expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original code in ensureTerminator was replacing the explicitly-provided terminator with a different one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice improvement to the verifier, LGTM. My one comment isn't directly related to the PR changes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[not directly related to PR] vector.mask is not an op I've encountered before so I'm still wrapping my head around it. I guess this paragraph is tested here. I think it makes sense that
%0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
return %0 : vector<8xf32> just becomes return %a : vector<8xf32> based on the rest of the definition. But I'm wondering about the case where there is a passthru value,
%0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
return %0 : vector<8xf32> should this not become
%0 = arith.select %mask, %a, %passthru : vector<8xi1>, vector<8xf32>
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense from a correctness perspective. Let me add that canonicalization separately
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After seeing the comment from @newling , I've realised that we are special-casing an empty vector.mask. This example will not trigger the error:
%0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>That's a bit of inconsistency. Perhaps leave a TODO to address this at some point?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, my intent with this PR is to make clearer that an empty vector.mask is not a common valid case to mask operations and that may eventually go away. Let me clarify that a bit better in the doc. We would need the CSE equivalence issue to be fixed and improve some of the existing vector transformations. Definitely a target we should be moving towards.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
I would still appreciate a TODO or some comment - mostly for my future self as a reminder about this conversation :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Thanks for addressing my comments. I've left one more suggestion, but that's non-blocking and possibly a discussion for a separate PR.
6da9abb to
57c381f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for all the updates, LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
I would still appreciate a TODO or some comment - mostly for my future self as a reminder about this conversation :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
This PR improves the verifier for the `vector.mask` operation to make sure it's not applying masking semantics to operations defined outside of the `vector.mask` region. Documentation is updated to emphasize that and make it clearer, even though it already stated that. As part of this change, the logic that ensures that a terminator is present in the region mask has been simplified to make it less surprising to the user when a `vector.yield` is explicitly provided.
57c381f to
2da41ac
Compare
This PR improves the
vector.maskverifier to make sure it's not applying masking semantics to operations defined outside of thevector.maskregion. Documentation is updated to emphasize that and make it clearer, even though it already stated that.As part of this change, the logic that ensures that a terminator is present in the region mask has been simplified to make it less surprising to the user when a
vector.yieldis explicitly provided in the IR.