From 09a240bf67d90dba6d93095bfa6a8ac324d1401d Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Fri, 20 Jun 2025 16:16:05 +0000 Subject: [PATCH 1/3] [MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to `bfmmla` --- mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 35 +++++++++++++++++++ .../Transforms/LegalizeForLLVMExport.cpp | 10 ++++-- .../Dialect/ArmSVE/legalize-for-llvm.mlir | 9 +++++ mlir/test/Dialect/ArmSVE/roundtrip.mlir | 10 ++++++ mlir/test/Target/LLVMIR/arm-sve.mlir | 12 +++++++ 5 files changed, 73 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index 7385bb73b449a..c4007dd02c0d3 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure, "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } + +def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>]> { + let summary = "BFloat16 matrix multiply-accumulate"; + let description = [{ + BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices"; + + This operation multiplies the 2x4 BFloat16 matrix held in each 128-bit + segment of the first source vector by the 4x2 BFloat16 matrix in the + corresponding segment of the second source vector, then accumulates + this intermediate result with the 2x2 Float32 matrix in the corresponding + segment of the accumulator vector, yielding the final 2x2 Float32 + segment of the result. + + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports (vector<[8]xbf16>, vector<[8]xbf16>) -> (vector<[4]xf32>) + let arguments = (ins + ScalableVectorOfLengthAndType<[4], [F32]>:$acc, + ScalableVectorOfLengthAndType<[8], [BF16]>:$src1, + ScalableVectorOfLengthAndType<[8], [BF16]>:$src2 + ); + let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; +} + class SvboolTypeConstraint : TypesMatchWith< "expected corresponding svbool type widened to [16]xi1", lhsArg, rhsArg, @@ -590,6 +619,12 @@ def UsmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"usmmla">, Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; +def BfmmlaIntrOp : + ArmSVE_IntrOp<"bfmmla", [Pure, TypeIs<"res", ScalableVectorOfLengthAndType<[4], [F32]>>]>, + Arguments<(ins Arg, "acc">:$acc, + Arg, "lhs">:$lhs, + Arg, "rhs">:$rhs)>; + def SdotIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sdot">, Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index 35f2a02cc4ec6..73f388b6d81c0 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -25,6 +25,7 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern; using UdotOpLowering = OneToOneConvertToLLVMPattern; using UmmlaOpLowering = OneToOneConvertToLLVMPattern; using UsmmlaOpLowering = OneToOneConvertToLLVMPattern; +using BfmmlaOpLowering = OneToOneConvertToLLVMPattern; using DupQLaneLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedAddIOpLowering = @@ -191,7 +192,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns( // Populate conversion patterns // clang-format off - patterns.add(); - target.addIllegalOp, // ----- +func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>, + %b: vector<[8]xbf16>, + %c: vector<[4]xf32>) -> vector<[4]xf32> { + // CHECK: arm_sve.intr.bfmmla + %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32> + return %0 : vector<[4]xf32> +} +// ----- + func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir index 64e0cff39eb06..9a653df767400 100644 --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -55,6 +55,16 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>, // ----- +func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>, + %b: vector<[8]xbf16>, + %c: vector<[4]xf32>) -> vector<[4]xf32> { + // CHECK: arm_sve.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32> + %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32> + return %0 : vector<[4]xf32> +} + +// ----- + func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir index da71cb5a63bd2..737145c74e331 100644 --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -60,6 +60,18 @@ llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>, llvm.return %0 : vector<[4]xi32> } +// CHECK-LABEL: define @arm_sve_bfmmla +llvm.func @arm_sve_bfmmla(%arg0: vector<[8]xbf16>, + %arg1: vector<[8]xbf16>, + %arg2: vector<[4]xf32>) + -> vector<[4]xf32> { + // CHECK: call @llvm.aarch64.sve.bfmmla( + %0 = "arm_sve.intr.bfmmla"(%arg2, %arg0, %arg1) : + (vector<[4]xf32>, vector<[8]xbf16>, vector<[8]xbf16>) + -> vector<[4]xf32> + llvm.return %0 : vector<[4]xf32> +} + // CHECK-LABEL: define @arm_sve_arithi llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>, %arg1: vector<[4]xi32>, From 690b549719f86feac8ff9c04474cba1b234f27b2 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Wed, 25 Jun 2025 13:12:04 +0000 Subject: [PATCH 2/3] [fixup] Skip the two-stage LLVM IR generation, map the op directly to the LLVM IR intrinsic --- mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 18 ++++++------------ .../Transforms/LegalizeForLLVMExport.cpp | 9 +++------ .../test/Dialect/ArmSVE/legalize-for-llvm.mlir | 9 --------- mlir/test/Dialect/ArmSVE/roundtrip.mlir | 4 ++-- 4 files changed, 11 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index c4007dd02c0d3..8988df680b8f9 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -293,10 +293,10 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure, "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } - -def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure, - AllTypesMatch<["src1", "src2"]>, - AllTypesMatch<["acc", "dst"]>]> { +def BfmmlaOp : ArmSVE_IntrOp<"bfmmla", [Pure, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "res"]>, + ]> { let summary = "BFloat16 matrix multiply-accumulate"; let description = [{ BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices"; @@ -317,9 +317,9 @@ def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure, ScalableVectorOfLengthAndType<[8], [BF16]>:$src1, ScalableVectorOfLengthAndType<[8], [BF16]>:$src2 ); - let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$dst); + let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$res); let assemblyFormat = - "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)"; } class SvboolTypeConstraint : TypesMatchWith< @@ -619,12 +619,6 @@ def UsmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"usmmla">, Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; -def BfmmlaIntrOp : - ArmSVE_IntrOp<"bfmmla", [Pure, TypeIs<"res", ScalableVectorOfLengthAndType<[4], [F32]>>]>, - Arguments<(ins Arg, "acc">:$acc, - Arg, "lhs">:$lhs, - Arg, "rhs">:$rhs)>; - def SdotIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sdot">, Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index 73f388b6d81c0..006332b48325f 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -25,7 +25,6 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern; using UdotOpLowering = OneToOneConvertToLLVMPattern; using UmmlaOpLowering = OneToOneConvertToLLVMPattern; using UsmmlaOpLowering = OneToOneConvertToLLVMPattern; -using BfmmlaOpLowering = OneToOneConvertToLLVMPattern; using DupQLaneLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedAddIOpLowering = @@ -192,8 +191,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns( // Populate conversion patterns // clang-format off - patterns.add(); - target.addIllegalOp, // ----- -func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>, - %b: vector<[8]xbf16>, - %c: vector<[4]xf32>) -> vector<[4]xf32> { - // CHECK: arm_sve.intr.bfmmla - %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32> - return %0 : vector<[4]xf32> -} -// ----- - func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir index 9a653df767400..b7b9329f1cb5a 100644 --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -58,8 +58,8 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>, func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>, %b: vector<[8]xbf16>, %c: vector<[4]xf32>) -> vector<[4]xf32> { - // CHECK: arm_sve.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32> - %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32> + // CHECK: arm_sve.intr.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32> + %0 = arm_sve.intr.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32> return %0 : vector<[4]xf32> } From 9bcd62eb2ae16405c32bd6d5af3e5e9d9d815cf7 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Thu, 26 Jun 2025 12:24:32 +0000 Subject: [PATCH 3/3] [fixup] Add some tests with invalid operands --- mlir/test/Dialect/ArmSVE/invalid.mlir | 60 +++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir index a021d4393107c..1d2529be0560b 100644 --- a/mlir/test/Dialect/ArmSVE/invalid.mlir +++ b/mlir/test/Dialect/ArmSVE/invalid.mlir @@ -72,3 +72,63 @@ func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) { arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1> return } + +// ----- + +func.func @bfmmla_invalid_element_type_lhs_rhs(%acc: vector<[4]xf32>, + %lhs: vector<[8]xf16>, + %rhs: vector<[8]xf16>) -> vector<[4]xf32> { + // expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<[8]xf16>'}} + %0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xf16> to vector<[4]xf32> + return %0 : vector<[4]xf32> +} + +// ----- + +func.func @bfmmla_invalid_dimension_lhs_rhs(%acc: vector<[4]xf32>, + %lhs: vector<[4]xbf16>, + %rhs: vector<[4]xbf16>) -> vector<[4]xf32> { + // expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<[4]xbf16>}} + %0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[4]xbf16> to vector<[4]xf32> + return %0 : vector<[4]xf32> +} + +// ----- + +func.func @bfmmla_fixed_dimension_lhs_rhs(%acc: vector<[4]xf32>, + %lhs: vector<8xbf16>, + %rhs: vector<8xbf16>) -> vector<[4]xf32> { + // expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<8xbf16>}} + %0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<8xbf16> to vector<[4]xf32> + return %0 : vector<[4]xf32> +} + +// ----- + +func.func @bfmmla_invalid_element_type_acc(%acc: vector<[4]xi32>, + %lhs: vector<[8]xbf16>, + %rhs: vector<[8]xbf16>) -> vector<[4]xi32> { + // expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<[4]xi32>'}} + %0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[4]xi32> + return %0 : vector<[4]xi32> +} + +// ----- + +func.func @bfmmla_invalid_dimension_acc(%acc: vector<[8]xf32>, + %lhs: vector<[8]xbf16>, + %rhs: vector<[8]xbf16>) -> vector<[8]xf32> { + // expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<[8]xf32>'}} + %0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[8]xf32> + return %0 : vector<[8]xf32> +} + +// ----- + +func.func @bfmmla_fixed_dimension_acc(%acc: vector<4xf32>, + %lhs: vector<[8]xbf16>, + %rhs: vector<[8]xbf16>) -> vector<4xf32> { + // expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<4xf32>'}} + %0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<4xf32> + return %0 : vector<4xf32> +}