Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
LogicalResult
matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
if (auto outerProductOp =
llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
maskOp.getMaskableOp())) {
LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
getContext());
return static_cast<RewritePattern &>(pattern).matchAndRewrite(
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/ArmSME/vector-legalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -646,3 +646,11 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
return
}

// -----

// From: https://github.com/llvm/llvm-project/issues/118449 (check we don't crash).
func.func @vector_mask_empty(%m0: vector<16x2xi1>, %arg1: vector<16x16xf32>) -> vector<16x16xf32> {
%0 = vector.mask %m0 { vector.yield %arg1 : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32>
return %0 : vector<16x16xf32>
}
Loading