From d76f30cf7fd32f40fbe464058568cef767cd7899 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 4 Dec 2024 11:17:54 +0000 Subject: [PATCH 1/2] [mlir][ArmSME] Fix crash on empty vector.mask in arm-sme-vector-legalization Fixes #118449 --- mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp | 4 ++-- mlir/test/Dialect/ArmSME/vector-legalization.mlir | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index e908a536e6fb2..61767f3b21c9c 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -265,8 +265,8 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition LogicalResult matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { - if (auto outerProductOp = - llvm::dyn_cast(maskOp.getMaskableOp())) { + if (auto outerProductOp = llvm::dyn_cast_or_null( + maskOp.getMaskableOp())) { LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(), getContext()); return static_cast(pattern).matchAndRewrite( diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir index 458906a187982..2f33007720258 100644 --- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir +++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir @@ -646,3 +646,11 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref return } + +// ----- + +// From: https://github.com/llvm/llvm-project/issues/118449 (check we don't crash). +func.func @vector_mask_empty(%m0: vector<16x2xi1>, %arg1: vector<16x16xf32>) -> vector<16x16xf32> { + %0 = vector.mask %m0 { vector.yield %arg1 : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32> + return %0 : vector<16x16xf32> +} From 56c87699684916f2273cbc515355d4f85695f1c1 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 4 Dec 2024 17:50:49 +0000 Subject: [PATCH 2/2] Fixup --- mlir/test/Dialect/ArmSME/vector-legalization.mlir | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir index 2f33007720258..d56df9814f173 100644 --- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir +++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir @@ -649,8 +649,10 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect // ----- -// From: https://github.com/llvm/llvm-project/issues/118449 (check we don't crash). -func.func @vector_mask_empty(%m0: vector<16x2xi1>, %arg1: vector<16x16xf32>) -> vector<16x16xf32> { - %0 = vector.mask %m0 { vector.yield %arg1 : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32> +// From: https://github.com/llvm/llvm-project/issues/118449. +// Check -arm-sme-vector-legalization does not crash when it encounters a `vector.mask` that +// does not contain a maskable op. +func.func @vector_mask_without_maskable_op(%mask: vector<16x2xi1>, %vec: vector<16x16xf32>) -> vector<16x16xf32> { + %0 = vector.mask %mask { vector.yield %vec : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32> return %0 : vector<16x16xf32> }