Skip to content

Conversation

@momchil-velikov
Copy link
Collaborator

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Momchil Velikov (momchil-velikov)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/145038.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td (+28)
  • (modified) mlir/test/Dialect/ArmNeon/roundtrip.mlir (+12)
  • (modified) mlir/test/Target/LLVMIR/arm-neon.mlir (+13)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 475b11f12c5f0..ce86ff2cfd922 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -222,6 +222,34 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
 }
 
+def BfmmlaOp : ArmNeon_IntrOp<"bfmmla", [], [], 1, [
+                 Pure,
+                 AllTypesMatch<["src1", "src2"]>,
+                 AllTypesMatch<["acc", "res"]>,
+               ]> {
+  let summary = "BFloat16 matrix multiply-accumulate to single-precision";
+  let description = [{
+    BFMMLA: BFloat16 matrix multiply-accumulate to single-precision.
+
+    The operation multiplies the 2x4 BFloat16 matrix in the first source vector
+    with the 4x2 BFloat16 matrix in the second source vector, then accumulates
+    this intermediate result with the 2x2 Float32 matrix in the accumulator
+    vector, yielding the final 2x2 Float32 result.
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/intrinsics/vbfmmlaq_f32
+  }];
+  // Supports (vector<8xbf16>, vector<8xbf16>) -> (vector<2xf32>)
+  let arguments = (ins
+    NeonVectorOfLength<4, F32>:$acc,
+    NeonVectorOfLength<8, BF16>:$src1,
+    NeonVectorOfLength<8, BF16>:$src2
+  );
+  let results = (outs NeonVectorOfLength<4, F32>:$res);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
+}
+
 class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
     : Op</*dialect=*/ArmNeon_Dialect,
          /*opName=*/"2d." # mnemonic,
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index b5df0ffa8105c..60133ce0fa6f3 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -60,3 +60,15 @@ func.func @arm_neon_usmmla(%a: vector<16xi8>,
   %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
   return %0 : vector<4xi32>
 }
+
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+func.func @arm_neon_bfmmla(%a: vector<8xbf16>,
+                           %b: vector<8xbf16>,
+                           %c: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: arm_neon.intr.bfmmla {{.*}}: vector<8xbf16> to vector<4xf32>
+  %0 = arm_neon.intr.bfmmla %c, %a, %b : vector<8xbf16> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index e096172667c9f..e1328ad448f0a 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -82,3 +82,16 @@ llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
         -> vector<4xi32>
   llvm.return %0 : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+llvm.func @arm_neon_bfmmla(%arg0: vector<8xbf16>,
+                           %arg1: vector<8xbf16>,
+                           %arg2: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float
+  %0 = "arm_neon.intr.bfmmla"(%arg2, %arg0, %arg1) :
+    (vector<4xf32>, vector<8xbf16>, vector<8xbf16>)
+        -> vector<4xf32>
+  llvm.return %0 : vector<4xf32>
+}

@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-mlir-neon

Author: Momchil Velikov (momchil-velikov)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/145038.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td (+28)
  • (modified) mlir/test/Dialect/ArmNeon/roundtrip.mlir (+12)
  • (modified) mlir/test/Target/LLVMIR/arm-neon.mlir (+13)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 475b11f12c5f0..ce86ff2cfd922 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -222,6 +222,34 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
 }
 
+def BfmmlaOp : ArmNeon_IntrOp<"bfmmla", [], [], 1, [
+                 Pure,
+                 AllTypesMatch<["src1", "src2"]>,
+                 AllTypesMatch<["acc", "res"]>,
+               ]> {
+  let summary = "BFloat16 matrix multiply-accumulate to single-precision";
+  let description = [{
+    BFMMLA: BFloat16 matrix multiply-accumulate to single-precision.
+
+    The operation multiplies the 2x4 BFloat16 matrix in the first source vector
+    with the 4x2 BFloat16 matrix in the second source vector, then accumulates
+    this intermediate result with the 2x2 Float32 matrix in the accumulator
+    vector, yielding the final 2x2 Float32 result.
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/intrinsics/vbfmmlaq_f32
+  }];
+  // Supports (vector<8xbf16>, vector<8xbf16>) -> (vector<2xf32>)
+  let arguments = (ins
+    NeonVectorOfLength<4, F32>:$acc,
+    NeonVectorOfLength<8, BF16>:$src1,
+    NeonVectorOfLength<8, BF16>:$src2
+  );
+  let results = (outs NeonVectorOfLength<4, F32>:$res);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
+}
+
 class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
     : Op</*dialect=*/ArmNeon_Dialect,
          /*opName=*/"2d." # mnemonic,
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index b5df0ffa8105c..60133ce0fa6f3 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -60,3 +60,15 @@ func.func @arm_neon_usmmla(%a: vector<16xi8>,
   %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
   return %0 : vector<4xi32>
 }
+
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+func.func @arm_neon_bfmmla(%a: vector<8xbf16>,
+                           %b: vector<8xbf16>,
+                           %c: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: arm_neon.intr.bfmmla {{.*}}: vector<8xbf16> to vector<4xf32>
+  %0 = arm_neon.intr.bfmmla %c, %a, %b : vector<8xbf16> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index e096172667c9f..e1328ad448f0a 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -82,3 +82,16 @@ llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
         -> vector<4xi32>
   llvm.return %0 : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+llvm.func @arm_neon_bfmmla(%arg0: vector<8xbf16>,
+                           %arg1: vector<8xbf16>,
+                           %arg2: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float
+  %0 = "arm_neon.intr.bfmmla"(%arg2, %arg0, %arg1) :
+    (vector<4xf32>, vector<8xbf16>, vector<8xbf16>)
+        -> vector<4xf32>
+  llvm.return %0 : vector<4xf32>
+}

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Could you add tests in invalid.mlir as well?

@momchil-velikov
Copy link
Collaborator Author

Thanks!

Could you add tests in invalid.mlir as well?

Tests added.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@momchil-velikov momchil-velikov merged commit a226542 into main Jun 26, 2025
7 checks passed
@momchil-velikov momchil-velikov deleted the users/momchil-velikov/bfmmla-neon branch June 26, 2025 09:43
@momchil-velikov
Copy link
Collaborator Author

Eh, forgot to add negative tests with scalable vectors. Will add them in #145064

@momchil-velikov
Copy link
Collaborator Author

Actually created a new PR: #145882

anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants