-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][ArmNeon] Add an ArmNeon operation which maps to bfmmla
#145038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Momchil Velikov (momchil-velikov) ChangesFull diff: https://github.com/llvm/llvm-project/pull/145038.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 475b11f12c5f0..ce86ff2cfd922 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -222,6 +222,34 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}
+def BfmmlaOp : ArmNeon_IntrOp<"bfmmla", [], [], 1, [
+ Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "res"]>,
+ ]> {
+ let summary = "BFloat16 matrix multiply-accumulate to single-precision";
+ let description = [{
+ BFMMLA: BFloat16 matrix multiply-accumulate to single-precision.
+
+ The operation multiplies the 2x4 BFloat16 matrix in the first source vector
+ with the 4x2 BFloat16 matrix in the second source vector, then accumulates
+ this intermediate result with the 2x2 Float32 matrix in the accumulator
+ vector, yielding the final 2x2 Float32 result.
+
+ Source:
+ https://developer.arm.com/architectures/instruction-sets/intrinsics/vbfmmlaq_f32
+ }];
+ // Supports (vector<8xbf16>, vector<8xbf16>) -> (vector<2xf32>)
+ let arguments = (ins
+ NeonVectorOfLength<4, F32>:$acc,
+ NeonVectorOfLength<8, BF16>:$src1,
+ NeonVectorOfLength<8, BF16>:$src2
+ );
+ let results = (outs NeonVectorOfLength<4, F32>:$res);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
+}
+
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
: Op</*dialect=*/ArmNeon_Dialect,
/*opName=*/"2d." # mnemonic,
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index b5df0ffa8105c..60133ce0fa6f3 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -60,3 +60,15 @@ func.func @arm_neon_usmmla(%a: vector<16xi8>,
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
return %0 : vector<4xi32>
}
+
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+func.func @arm_neon_bfmmla(%a: vector<8xbf16>,
+ %b: vector<8xbf16>,
+ %c: vector<4xf32>) -> vector<4xf32> {
+ // CHECK: arm_neon.intr.bfmmla {{.*}}: vector<8xbf16> to vector<4xf32>
+ %0 = arm_neon.intr.bfmmla %c, %a, %b : vector<8xbf16> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index e096172667c9f..e1328ad448f0a 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -82,3 +82,16 @@ llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
-> vector<4xi32>
llvm.return %0 : vector<4xi32>
}
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+llvm.func @arm_neon_bfmmla(%arg0: vector<8xbf16>,
+ %arg1: vector<8xbf16>,
+ %arg2: vector<4xf32>) -> vector<4xf32> {
+ // CHECK: <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float
+ %0 = "arm_neon.intr.bfmmla"(%arg2, %arg0, %arg1) :
+ (vector<4xf32>, vector<8xbf16>, vector<8xbf16>)
+ -> vector<4xf32>
+ llvm.return %0 : vector<4xf32>
+}
|
|
@llvm/pr-subscribers-mlir-neon Author: Momchil Velikov (momchil-velikov) ChangesFull diff: https://github.com/llvm/llvm-project/pull/145038.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 475b11f12c5f0..ce86ff2cfd922 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -222,6 +222,34 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}
+def BfmmlaOp : ArmNeon_IntrOp<"bfmmla", [], [], 1, [
+ Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "res"]>,
+ ]> {
+ let summary = "BFloat16 matrix multiply-accumulate to single-precision";
+ let description = [{
+ BFMMLA: BFloat16 matrix multiply-accumulate to single-precision.
+
+ The operation multiplies the 2x4 BFloat16 matrix in the first source vector
+ with the 4x2 BFloat16 matrix in the second source vector, then accumulates
+ this intermediate result with the 2x2 Float32 matrix in the accumulator
+ vector, yielding the final 2x2 Float32 result.
+
+ Source:
+ https://developer.arm.com/architectures/instruction-sets/intrinsics/vbfmmlaq_f32
+ }];
+ // Supports (vector<8xbf16>, vector<8xbf16>) -> (vector<2xf32>)
+ let arguments = (ins
+ NeonVectorOfLength<4, F32>:$acc,
+ NeonVectorOfLength<8, BF16>:$src1,
+ NeonVectorOfLength<8, BF16>:$src2
+ );
+ let results = (outs NeonVectorOfLength<4, F32>:$res);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
+}
+
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
: Op</*dialect=*/ArmNeon_Dialect,
/*opName=*/"2d." # mnemonic,
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index b5df0ffa8105c..60133ce0fa6f3 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -60,3 +60,15 @@ func.func @arm_neon_usmmla(%a: vector<16xi8>,
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
return %0 : vector<4xi32>
}
+
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+func.func @arm_neon_bfmmla(%a: vector<8xbf16>,
+ %b: vector<8xbf16>,
+ %c: vector<4xf32>) -> vector<4xf32> {
+ // CHECK: arm_neon.intr.bfmmla {{.*}}: vector<8xbf16> to vector<4xf32>
+ %0 = arm_neon.intr.bfmmla %c, %a, %b : vector<8xbf16> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index e096172667c9f..e1328ad448f0a 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -82,3 +82,16 @@ llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
-> vector<4xi32>
llvm.return %0 : vector<4xi32>
}
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+llvm.func @arm_neon_bfmmla(%arg0: vector<8xbf16>,
+ %arg1: vector<8xbf16>,
+ %arg2: vector<4xf32>) -> vector<4xf32> {
+ // CHECK: <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float
+ %0 = "arm_neon.intr.bfmmla"(%arg2, %arg0, %arg1) :
+ (vector<4xf32>, vector<8xbf16>, vector<8xbf16>)
+ -> vector<4xf32>
+ llvm.return %0 : vector<4xf32>
+}
|
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Could you add tests in invalid.mlir as well?
Tests added. |
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
|
Eh, forgot to add negative tests with scalable vectors. Will add them in #145064 |
|
Actually created a new PR: #145882 |
No description provided.