Skip to content

Conversation

@MacDue
Copy link
Member

@MacDue MacDue commented Dec 4, 2024

Fixes #118449

@llvmbot
Copy link
Member

llvmbot commented Dec 4, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

Changes

Fixes #118449


Full diff: https://github.com/llvm/llvm-project/pull/118613.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+2-2)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+8)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index e908a536e6fb27..61767f3b21c9c3 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -265,8 +265,8 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
   LogicalResult
   matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                   OneToNPatternRewriter &rewriter) const override {
-    if (auto outerProductOp =
-            llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
+    if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
+            maskOp.getMaskableOp())) {
       LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                            getContext());
       return static_cast<RewritePattern &>(pattern).matchAndRewrite(
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 458906a1879829..2f33007720258b 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -646,3 +646,11 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
   vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>,  memref<?x?xf32>
   return
 }
+
+// -----
+
+// From: https://github.com/llvm/llvm-project/issues/118449 (check we don't crash).
+func.func @vector_mask_empty(%m0: vector<16x2xi1>, %arg1: vector<16x16xf32>) -> vector<16x16xf32> {
+  %0 = vector.mask %m0 { vector.yield %arg1 : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32>
+  return %0 : vector<16x16xf32>
+}

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix Ben and for paying attention to GitHub issues 🙏🏻

Makes sense, but I'd appreciate some clarification in comments and code. In particular, "crash on empty vector.mask" -> "crash on vector.mask masking an empty block"? That wasn't immediately obvious to me.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@MacDue MacDue merged commit a9eb8f0 into llvm:main Dec 5, 2024
8 checks passed
@MacDue MacDue deleted the crash_fix_mlir branch December 5, 2024 09:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir] -arm-sme-vector-legalization crashes

4 participants