Skip to content

Commit 154b104

Browse files
[MLIR][ArmNeon] Add an ArmNeon operation which maps to bfmmla
1 parent eb694b2 commit 154b104

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,34 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
222222
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
223223
}
224224

225+
def BfmmlaOp : ArmNeon_IntrOp<"bfmmla", [], [], 1, [
226+
Pure,
227+
AllTypesMatch<["src1", "src2"]>,
228+
AllTypesMatch<["acc", "res"]>,
229+
]> {
230+
let summary = "BFloat16 matrix multiply-accumulate to single-precision";
231+
let description = [{
232+
BFMMLA: BFloat16 matrix multiply-accumulate to single-precision.
233+
234+
The operation multiplies the 2x4 BFloat16 matrix in the first source vector
235+
with the 4x2 BFloat16 matrix in the second source vector, then accumulates
236+
this intermediate result with the 2x2 Float32 matrix in the accumulator
237+
vector, yielding the final 2x2 Float32 result.
238+
239+
Source:
240+
https://developer.arm.com/architectures/instruction-sets/intrinsics/vbfmmlaq_f32
241+
}];
242+
// Supports (vector<8xbf16>, vector<8xbf16>) -> (vector<2xf32>)
243+
let arguments = (ins
244+
NeonVectorOfLength<4, F32>:$acc,
245+
NeonVectorOfLength<8, BF16>:$src1,
246+
NeonVectorOfLength<8, BF16>:$src2
247+
);
248+
let results = (outs NeonVectorOfLength<4, F32>:$res);
249+
let assemblyFormat =
250+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
251+
}
252+
225253
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
226254
: Op</*dialect=*/ArmNeon_Dialect,
227255
/*opName=*/"2d." # mnemonic,

mlir/test/Dialect/ArmNeon/roundtrip.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,15 @@ func.func @arm_neon_usmmla(%a: vector<16xi8>,
6060
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
6161
return %0 : vector<4xi32>
6262
}
63+
64+
65+
// -----
66+
67+
// CHECK-LABEL: arm_neon_bfmmla
68+
func.func @arm_neon_bfmmla(%a: vector<8xbf16>,
69+
%b: vector<8xbf16>,
70+
%c: vector<4xf32>) -> vector<4xf32> {
71+
// CHECK: arm_neon.intr.bfmmla {{.*}}: vector<8xbf16> to vector<4xf32>
72+
%0 = arm_neon.intr.bfmmla %c, %a, %b : vector<8xbf16> to vector<4xf32>
73+
return %0 : vector<4xf32>
74+
}

mlir/test/Target/LLVMIR/arm-neon.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,16 @@ llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
8282
-> vector<4xi32>
8383
llvm.return %0 : vector<4xi32>
8484
}
85+
86+
// -----
87+
88+
// CHECK-LABEL: arm_neon_bfmmla
89+
llvm.func @arm_neon_bfmmla(%arg0: vector<8xbf16>,
90+
%arg1: vector<8xbf16>,
91+
%arg2: vector<4xf32>) -> vector<4xf32> {
92+
// CHECK: <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float
93+
%0 = "arm_neon.intr.bfmmla"(%arg2, %arg0, %arg1) :
94+
(vector<4xf32>, vector<8xbf16>, vector<8xbf16>)
95+
-> vector<4xf32>
96+
llvm.return %0 : vector<4xf32>
97+
}

0 commit comments

Comments
 (0)