-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][ArmSME] Fix crash on empty vector.mask in arm-sme-vector-legalization #118613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme Author: Benjamin Maxwell (MacDue) ChangesFixes #118449 Full diff: https://github.com/llvm/llvm-project/pull/118613.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index e908a536e6fb27..61767f3b21c9c3 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -265,8 +265,8 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
LogicalResult
matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
- if (auto outerProductOp =
- llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
+ if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
+ maskOp.getMaskableOp())) {
LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
getContext());
return static_cast<RewritePattern &>(pattern).matchAndRewrite(
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 458906a1879829..2f33007720258b 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -646,3 +646,11 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
return
}
+
+// -----
+
+// From: https://github.com/llvm/llvm-project/issues/118449 (check we don't crash).
+func.func @vector_mask_empty(%m0: vector<16x2xi1>, %arg1: vector<16x16xf32>) -> vector<16x16xf32> {
+ %0 = vector.mask %m0 { vector.yield %arg1 : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32>
+ return %0 : vector<16x16xf32>
+}
|
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix Ben and for paying attention to GitHub issues 🙏🏻
Makes sense, but I'd appreciate some clarification in comments and code. In particular, "crash on empty vector.mask" -> "crash on vector.mask masking an empty block"? That wasn't immediately obvious to me.
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Fixes #118449