Skip to content

Conversation

@dcaballe
Copy link
Contributor

This PR improves the vector.mask verifier to make sure it's not applying masking semantics to operations defined outside of the vector.mask region. Documentation is updated to emphasize that and make it clearer, even though it already stated that.

As part of this change, the logic that ensures that a terminator is present in the region mask has been simplified to make it less surprising to the user when a vector.yield is explicitly provided in the IR.

@llvmbot
Copy link
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR improves the vector.mask verifier to make sure it's not applying masking semantics to operations defined outside of the vector.mask region. Documentation is updated to emphasize that and make it clearer, even though it already stated that.

As part of this change, the logic that ensures that a terminator is present in the region mask has been simplified to make it less surprising to the user when a vector.yield is explicitly provided in the IR.


Full diff: https://github.com/llvm/llvm-project/pull/139823.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+7-2)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+24-17)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+28)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3aefcea8de994..2820759687293 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2482,8 +2482,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
     masked. Values used within the region are captured from above. Only one
     *maskable* operation can be masked with a `vector.mask` operation at a time.
     An operation is *maskable* if it implements the `MaskableOpInterface`. The
-    terminator yields all results of the maskable operation to the result of
-    this operation.
+    terminator yields all results from the maskable operation to the result of
+    this operation. No other values are allowed to be yielded.
+
+    An empty `vector.mask` operation is considered ill-formed but legal to
+    facilitate optimizations across the `vector.mask` operation. It is considered
+    a no-op regardless of its returned values and will be removed by the
+    canonicalizer.
 
     The vector mask argument holds a bit for each vector lane and determines
     which vector lanes should execute the maskable operation and which ones
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..395680b5e814b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6543,29 +6543,31 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
 }
 
 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
-  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
-      MaskOp>::ensureTerminator(region, builder, loc);
-  // Keep the default yield terminator if the number of masked operations is not
-  // the expected. This case will trigger a verification failure.
-  Block &block = region.front();
-  if (block.getOperations().size() != 2)
+  // Create default terminator if there are no ops to mask.
+  if (region.empty() || region.front().empty()) {
+    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+        MaskOp>::ensureTerminator(region, builder, loc);
     return;
+  }
 
-  // Replace default yield terminator with a new one that returns the results
-  // from the masked operation.
-  OpBuilder opBuilder(builder.getContext());
-  Operation *maskedOp = &block.front();
-  Operation *oldYieldOp = &block.back();
-  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
+  // If  region has an explicit terminator, we don't modify it.
+  Block &block = region.front();
+  if (isa<vector::YieldOp>(block.back()))
+    return;
 
-  // Empty vector.mask op.
-  if (maskedOp == oldYieldOp)
+  // Create default terminator if the number of masked operations is not
+  // one. This case will trigger a verification failure.
+  if (block.getOperations().size() != 1) {
+    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+        MaskOp>::ensureTerminator(region, builder, loc);
     return;
+  }
 
-  opBuilder.setInsertionPoint(oldYieldOp);
+  // Create a terminator that yields the results from the masked operation.
+  OpBuilder opBuilder(builder.getContext());
+  Operation *maskedOp = &block.front();
+  opBuilder.setInsertionPointToEnd(&block);
   opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
-  oldYieldOp->dropAllReferences();
-  oldYieldOp->erase();
 }
 
 LogicalResult MaskOp::verify() {
@@ -6600,6 +6602,11 @@ LogicalResult MaskOp::verify() {
     return emitOpError("expects number of results to match maskable operation "
                        "number of results");
 
+  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
+    return emitOpError(
+        "expects all the results from the MaskableOpInterface to "
+        "be returned by the terminator");
+
   if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
     return emitOpError(
         "expects result type to match maskable operation result type");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index be65d4c2eef58..de6333ca8dc7b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1747,6 +1747,34 @@ func.func @vector_mask_0d_mask(%arg0: tensor<2x4xi32>,
 
 // -----
 
+func.func @vector_mask_non_empty_external_return(%t0: tensor<?xf32>, %idx: index,
+                                                 %m0: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> {
+  %ft0 = arith.constant 0.0 : f32
+  // expected-error@+1 {{'vector.mask' op expects all the results from the MaskableOpInterface to be returned by the terminator}}
+  %0 = vector.mask %m0 {
+    %1 =vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
+    vector.yield %ext : vector<16xf32>
+  } : vector<16xi1> -> vector<16xf32>
+
+  return %0 : vector<16xf32>
+}
+
+// -----
+
+func.func @vector_mask_non_empty_mixed_return(%t0: tensor<?xf32>, %idx: index,
+                                              %m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
+  %ft0 = arith.constant 0.0 : f32
+  // expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}}
+  %0:2 = vector.mask %m0 {
+    %1 =vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
+    vector.yield %1, %ext : vector<16xf32>, vector<16xf32>
+  } : vector<16xi1> -> (vector<16xf32>, vector<16xf32>)
+
+  return %0#0, %0#1 : vector<16xf32>, vector<16xf32>
+}
+
+// -----
+
 func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
   // expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
   %0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>

@dcaballe dcaballe requested a review from newling May 14, 2025 01:32
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thank you!

I've left a few comments, but these are mostly questions so that I better understand the underlying design.


// -----

func.func @vector_mask_non_empty_external_return(%t0: tensor<?xf32>, %idx: index,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[ultranit] In the spirit of "less is more" 😅 (similar suggestion elsewhere)

Suggested change
func.func @vector_mask_non_empty_external_return(%t0: tensor<?xf32>, %idx: index,
func.func @vector_mask_non_empty_external_return(%t: tensor<?xf32>, %idx: index,

terminator yields all results from the maskable operation to the result of
this operation. No other values are allowed to be yielded.

An empty `vector.mask` operation is considered ill-formed but legal to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want come across as nit-picking, but what does "ill-formed" mean? To me, the sentence that follows suggests that empty masks are totally fine, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better now?

Comment on lines 6607 to 6608
"expects all the results from the MaskableOpInterface to "
"be returned by the terminator");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will trigger even when "all the results from the MaskableOpInterface" are returned by the terminator.

Suggested change
"expects all the results from the MaskableOpInterface to "
"be returned by the terminator");
"expects all the results from the MaskableOpInterface to "
"match all the values returned by the terminator");

Comment on lines 6546 to 6578
// Create default terminator if there are no ops to mask.
if (region.empty() || region.front().empty()) {
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
MaskOp>::ensureTerminator(region, builder, loc);
return;
}

// Replace default yield terminator with a new one that returns the results
// from the masked operation.
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
Operation *oldYieldOp = &block.back();
assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
// If region has an explicit terminator, we don't modify it.
Block &block = region.front();
if (isa<vector::YieldOp>(block.back()))
return;

// Empty vector.mask op.
if (maskedOp == oldYieldOp)
// Create default terminator if the number of masked operations is not
// one. This case will trigger a verification failure.
if (block.getOperations().size() != 1) {
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
MaskOp>::ensureTerminator(region, builder, loc);
return;
}

opBuilder.setInsertionPoint(oldYieldOp);
// Create a terminator that yields the results from the masked operation.
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
opBuilder.setInsertionPointToEnd(&block);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not immediately clear what the cases are. IIUC, it's something like this:

  // 1. For an empty `vector.mask`, create a default terminator.
  if (region.empty() || region.front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }
  
  // 2. For a non-empty `vector.mask` _with_ an existing terminator, do nothing.
  Block &block = region.front();
  if (isa<vector::YieldOp>(block.back()))
    return;
    
  // 3. For a non-empty `vector.mask` _without_ a terminator, split into two cases.
  // 3.1. If the number of masked operations is != 1, create the default terminator (this case is invalid and will be flagged by the Op verifier).
  if (block.getOperations().size() != 1) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }
  
  // 3.2 Otherwise, create a terminator that yields all the results from the masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());

? Did I get it correctly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I highlighted the three use cases

func.func @vector_mask_non_empty_mixed_return(%t0: tensor<?xf32>, %idx: index,
%m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
%ft0 = arith.constant 0.0 : f32
// expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails today with this error:

 error: 'vector.mask' op expects number of results to match mask region yielded values

Its not clear to me what generates this new error.

Copy link
Contributor Author

@dcaballe dcaballe May 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original code in ensureTerminator was replacing the explicitly-provided terminator with a different one.

Copy link
Contributor Author

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback!

terminator yields all results from the maskable operation to the result of
this operation. No other values are allowed to be yielded.

An empty `vector.mask` operation is considered ill-formed but legal to
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better now?

Comment on lines 6546 to 6578
// Create default terminator if there are no ops to mask.
if (region.empty() || region.front().empty()) {
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
MaskOp>::ensureTerminator(region, builder, loc);
return;
}

// Replace default yield terminator with a new one that returns the results
// from the masked operation.
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
Operation *oldYieldOp = &block.back();
assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
// If region has an explicit terminator, we don't modify it.
Block &block = region.front();
if (isa<vector::YieldOp>(block.back()))
return;

// Empty vector.mask op.
if (maskedOp == oldYieldOp)
// Create default terminator if the number of masked operations is not
// one. This case will trigger a verification failure.
if (block.getOperations().size() != 1) {
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
MaskOp>::ensureTerminator(region, builder, loc);
return;
}

opBuilder.setInsertionPoint(oldYieldOp);
// Create a terminator that yields the results from the masked operation.
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
opBuilder.setInsertionPointToEnd(&block);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I highlighted the three use cases

func.func @vector_mask_non_empty_mixed_return(%t0: tensor<?xf32>, %idx: index,
%m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
%ft0 = arith.constant 0.0 : f32
// expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}}
Copy link
Contributor Author

@dcaballe dcaballe May 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original code in ensureTerminator was replacing the explicitly-provided terminator with a different one.

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice improvement to the verifier, LGTM. My one comment isn't directly related to the PR changes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[not directly related to PR] vector.mask is not an op I've encountered before so I'm still wrapping my head around it. I guess this paragraph is tested here. I think it makes sense that

%0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
return %0 : vector<8xf32> 

just becomes return %a : vector<8xf32> based on the rest of the definition. But I'm wondering about the case where there is a passthru value,

%0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
return %0 : vector<8xf32> 

should this not become

 %0 = arith.select %mask, %a, %passthru : vector<8xi1>, vector<8xf32>

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense from a correctness perspective. Let me add that canonicalization separately

Comment on lines +6608 to +6616
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After seeing the comment from @newling , I've realised that we are special-casing an empty vector.mask. This example will not trigger the error:

%0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>

That's a bit of inconsistency. Perhaps leave a TODO to address this at some point?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, my intent with this PR is to make clearer that an empty vector.mask is not a common valid case to mask operations and that may eventually go away. Let me clarify that a bit better in the doc. We would need the CSE equivalence issue to be fixed and improve some of the existing vector transformations. Definitely a target we should be moving towards.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

I would still appreciate a TODO or some comment - mostly for my future self as a reminder about this conversation :)

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Thanks for addressing my comments. I've left one more suggestion, but that's non-blocking and possibly a discussion for a separate PR.

@dcaballe dcaballe force-pushed the verify-vector-mask-nesting branch from 6da9abb to 57c381f Compare May 16, 2025 22:50
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the updates, LGTM

Comment on lines +6608 to +6616
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

I would still appreciate a TODO or some comment - mostly for my future self as a reminder about this conversation :)

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

dcaballe added 4 commits May 20, 2025 22:08
This PR improves the verifier for the `vector.mask` operation to
make sure it's not applying masking semantics to operations defined
outside of the `vector.mask` region. Documentation is updated to
emphasize that and make it clearer, even though it already stated that.

As part of this change, the logic that ensures that a terminator is
present in the region mask has been simplified to make it less surprising
to the user when a `vector.yield` is explicitly provided.
@dcaballe dcaballe force-pushed the verify-vector-mask-nesting branch from 57c381f to 2da41ac Compare May 20, 2025 22:18
@dcaballe dcaballe merged commit 6cac792 into llvm:main May 20, 2025
7 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants