From 2f14c0f69b6f704aa2d21d0a566ce8d93a268f77 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Tue, 13 May 2025 23:39:34 +0000 Subject: [PATCH 1/4] [mlir][Vector] Improve `vector.mask` verifier This PR improves the verifier for the `vector.mask` operation to make sure it's not applying masking semantics to operations defined outside of the `vector.mask` region. Documentation is updated to emphasize that and make it clearer, even though it already stated that. As part of this change, the logic that ensures that a terminator is present in the region mask has been simplified to make it less surprising to the user when a `vector.yield` is explicitly provided. --- .../mlir/Dialect/Vector/IR/VectorOps.td | 9 +++- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 41 +++++++++++-------- mlir/test/Dialect/Vector/invalid.mlir | 28 +++++++++++++ 3 files changed, 59 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 3aefcea8de994..2820759687293 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2482,8 +2482,13 @@ def Vector_MaskOp : Vector_Op<"mask", [ masked. Values used within the region are captured from above. Only one *maskable* operation can be masked with a `vector.mask` operation at a time. An operation is *maskable* if it implements the `MaskableOpInterface`. The - terminator yields all results of the maskable operation to the result of - this operation. + terminator yields all results from the maskable operation to the result of + this operation. No other values are allowed to be yielded. + + An empty `vector.mask` operation is considered ill-formed but legal to + facilitate optimizations across the `vector.mask` operation. It is considered + a no-op regardless of its returned values and will be removed by the + canonicalizer. The vector mask argument holds a bit for each vector lane and determines which vector lanes should execute the maskable operation and which ones diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 1b5534d4d94ff..25c3f439e877a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6550,29 +6550,31 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) { } void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) { - OpTrait::SingleBlockImplicitTerminator::Impl< - MaskOp>::ensureTerminator(region, builder, loc); - // Keep the default yield terminator if the number of masked operations is not - // the expected. This case will trigger a verification failure. - Block &block = region.front(); - if (block.getOperations().size() != 2) + // Create default terminator if there are no ops to mask. + if (region.empty() || region.front().empty()) { + OpTrait::SingleBlockImplicitTerminator::Impl< + MaskOp>::ensureTerminator(region, builder, loc); return; + } - // Replace default yield terminator with a new one that returns the results - // from the masked operation. - OpBuilder opBuilder(builder.getContext()); - Operation *maskedOp = &block.front(); - Operation *oldYieldOp = &block.back(); - assert(isa(oldYieldOp) && "Expected vector::YieldOp"); + // If region has an explicit terminator, we don't modify it. + Block &block = region.front(); + if (isa(block.back())) + return; - // Empty vector.mask op. - if (maskedOp == oldYieldOp) + // Create default terminator if the number of masked operations is not + // one. This case will trigger a verification failure. + if (block.getOperations().size() != 1) { + OpTrait::SingleBlockImplicitTerminator::Impl< + MaskOp>::ensureTerminator(region, builder, loc); return; + } - opBuilder.setInsertionPoint(oldYieldOp); + // Create a terminator that yields the results from the masked operation. + OpBuilder opBuilder(builder.getContext()); + Operation *maskedOp = &block.front(); + opBuilder.setInsertionPointToEnd(&block); opBuilder.create(loc, maskedOp->getResults()); - oldYieldOp->dropAllReferences(); - oldYieldOp->erase(); } LogicalResult MaskOp::verify() { @@ -6607,6 +6609,11 @@ LogicalResult MaskOp::verify() { return emitOpError("expects number of results to match maskable operation " "number of results"); + if (!llvm::equal(maskableOp->getResults(), terminator.getOperands())) + return emitOpError( + "expects all the results from the MaskableOpInterface to " + "be returned by the terminator"); + if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes())) return emitOpError( "expects result type to match maskable operation result type"); diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 740c6b7ae3174..bebb617f8aa09 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1756,6 +1756,20 @@ func.func @vector_mask_empty_passthru_no_return_type(%mask : vector<8xi1>, // ----- +func.func @vector_mask_non_empty_external_return(%t0: tensor, %idx: index, + %m0: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> { + %ft0 = arith.constant 0.0 : f32 + // expected-error@+1 {{'vector.mask' op expects all the results from the MaskableOpInterface to be returned by the terminator}} + %0 = vector.mask %m0 { + %1 =vector.transfer_read %t0[%idx], %ft0 : tensor, vector<16xf32> + vector.yield %ext : vector<16xf32> + } : vector<16xi1> -> vector<16xf32> + + return %0 : vector<16xf32> +} + +// ----- + func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>, %passthru : vector<8xi32>) { // expected-error@+1 {{'vector.mask' expects a result if passthru operand is provided}} @@ -1765,6 +1779,20 @@ func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>, // ----- +func.func @vector_mask_non_empty_mixed_return(%t0: tensor, %idx: index, + %m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) { + %ft0 = arith.constant 0.0 : f32 + // expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}} + %0:2 = vector.mask %m0 { + %1 =vector.transfer_read %t0[%idx], %ft0 : tensor, vector<16xf32> + vector.yield %1, %ext : vector<16xf32>, vector<16xf32> + } : vector<16xi1> -> (vector<16xf32>, vector<16xf32>) + + return %0#0, %0#1 : vector<16xf32>, vector<16xf32> +} + +// ----- + func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) { // expected-error@+1 {{op failed to verify that position is a multiple of the source length.}} %0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32> From d716bbb0871d653d19d1b57d7baf6f8fcdd4882f Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 14 May 2025 17:55:30 +0000 Subject: [PATCH 2/4] Review feedback --- .../mlir/Dialect/Vector/IR/VectorOps.td | 7 +++---- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 ++++++----- mlir/test/Dialect/Vector/invalid.mlir | 18 +++++++++--------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 2820759687293..e4c66a42ea333 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2485,10 +2485,9 @@ def Vector_MaskOp : Vector_Op<"mask", [ terminator yields all results from the maskable operation to the result of this operation. No other values are allowed to be yielded. - An empty `vector.mask` operation is considered ill-formed but legal to - facilitate optimizations across the `vector.mask` operation. It is considered - a no-op regardless of its returned values and will be removed by the - canonicalizer. + An empty `vector.mask` operation is legal to facilitate optimizations across + the `vector.mask` operation. However, it is considered a no-op regardless of + its returned values and will be removed by the canonicalizer. The vector mask argument holds a bit for each vector lane and determines which vector lanes should execute the maskable operation and which ones diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 25c3f439e877a..bbb366b01fa6e 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6550,18 +6550,20 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) { } void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) { - // Create default terminator if there are no ops to mask. + // 1. For an empty `vector.mask`, create a default terminator. if (region.empty() || region.front().empty()) { OpTrait::SingleBlockImplicitTerminator::Impl< MaskOp>::ensureTerminator(region, builder, loc); return; } - // If region has an explicit terminator, we don't modify it. + // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing. Block &block = region.front(); if (isa(block.back())) return; + // 3. For a non-empty `vector.mask` without an explicit terminator: + // Create default terminator if the number of masked operations is not // one. This case will trigger a verification failure. if (block.getOperations().size() != 1) { @@ -6610,9 +6612,8 @@ LogicalResult MaskOp::verify() { "number of results"); if (!llvm::equal(maskableOp->getResults(), terminator.getOperands())) - return emitOpError( - "expects all the results from the MaskableOpInterface to " - "be returned by the terminator"); + return emitOpError("expects all the results from the MaskableOpInterface " + "to match all the values returned by the terminator"); if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes())) return emitOpError( diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index bebb617f8aa09..04810ed52584f 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1756,12 +1756,12 @@ func.func @vector_mask_empty_passthru_no_return_type(%mask : vector<8xi1>, // ----- -func.func @vector_mask_non_empty_external_return(%t0: tensor, %idx: index, - %m0: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> { +func.func @vector_mask_non_empty_external_return(%t: tensor, %idx: index, + %m: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> { %ft0 = arith.constant 0.0 : f32 - // expected-error@+1 {{'vector.mask' op expects all the results from the MaskableOpInterface to be returned by the terminator}} - %0 = vector.mask %m0 { - %1 =vector.transfer_read %t0[%idx], %ft0 : tensor, vector<16xf32> + // expected-error@+1 {{'vector.mask' op expects all the results from the MaskableOpInterface to match all the values returned by the terminator}} + %0 = vector.mask %m { + %1 =vector.transfer_read %t[%idx], %ft0 : tensor, vector<16xf32> vector.yield %ext : vector<16xf32> } : vector<16xi1> -> vector<16xf32> @@ -1779,12 +1779,12 @@ func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>, // ----- -func.func @vector_mask_non_empty_mixed_return(%t0: tensor, %idx: index, - %m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) { +func.func @vector_mask_non_empty_mixed_return(%t: tensor, %idx: index, + %m: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) { %ft0 = arith.constant 0.0 : f32 // expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}} - %0:2 = vector.mask %m0 { - %1 =vector.transfer_read %t0[%idx], %ft0 : tensor, vector<16xf32> + %0:2 = vector.mask %m { + %1 =vector.transfer_read %t[%idx], %ft0 : tensor, vector<16xf32> vector.yield %1, %ext : vector<16xf32>, vector<16xf32> } : vector<16xi1> -> (vector<16xf32>, vector<16xf32>) From 13917011878ef930d11b07b58d3982d07887a97f Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Fri, 16 May 2025 22:49:44 +0000 Subject: [PATCH 3/4] Improve doc --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index e4c66a42ea333..cd734c796582f 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2485,9 +2485,9 @@ def Vector_MaskOp : Vector_Op<"mask", [ terminator yields all results from the maskable operation to the result of this operation. No other values are allowed to be yielded. - An empty `vector.mask` operation is legal to facilitate optimizations across - the `vector.mask` operation. However, it is considered a no-op regardless of - its returned values and will be removed by the canonicalizer. + An empty `vector.mask` operation is currently legal to enable optimizations + across the `vector.mask` region. However, this might change in the future + once vector transformations gain better support for `vector.mask`. The vector mask argument holds a bit for each vector lane and determines which vector lanes should execute the maskable operation and which ones From 2da41ace7085ed6bcfd98dc0388aa0169403f9a0 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Tue, 20 May 2025 22:17:53 +0000 Subject: [PATCH 4/4] Add TODO --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index cd734c796582f..5e8421ed67d66 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2488,6 +2488,7 @@ def Vector_MaskOp : Vector_Op<"mask", [ An empty `vector.mask` operation is currently legal to enable optimizations across the `vector.mask` region. However, this might change in the future once vector transformations gain better support for `vector.mask`. + TODO: Consider making empty `vector.mask` illegal. The vector mask argument holds a bit for each vector lane and determines which vector lanes should execute the maskable operation and which ones