Skip to content

Commit e677f83

Browse files
[AArch64] Implement NEON FP8 fused multiply-add intrinsics (non-indexed)
This patch adds the following intrinsics:

    float16x8_t vmlalbq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float16x8_t vmlaltq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlallbbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlallbtq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlalltbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlallttq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)

[fixup] Update intrinsics definitions
[fixup] Remove some opt passes from RUN lines
1 parent ae09723 commit e677f83

File tree

9 files changed

+274
-22
lines changed

9 files changed

+274
-22
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2161,6 +2161,16 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
21612161
def VDOTQ_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
21622162
}
21632163

2164+
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
2165+
def VMLALB_F16_F8 : VInst<"vmlalb_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
2166+
def VMLALT_F16_F8 : VInst<"vmlalt_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
2167+
2168+
def VMLALLBB_F32_F8 : VInst<"vmlallbb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2169+
def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2170+
def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2171+
def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2172+
}
2173+
21642174
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
21652175
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
21662176
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6759,11 +6759,14 @@ Value *CodeGenFunction::EmitNeonCall(Function *F, SmallVectorImpl<Value*> &Ops,
67596759
return Builder.CreateCall(F, Ops, name);
67606760
}
67616761

6762-
Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
6762+
Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
6763+
ArrayRef<llvm::Type *> Tys,
67636764
SmallVectorImpl<Value *> &Ops,
6764-
Value *FPM, const char *name) {
6765+
const CallExpr *E, const char *name) {
6766+
llvm::Value *FPM =
6767+
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
67656768
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), FPM);
6766-
return EmitNeonCall(F, Ops, name);
6769+
return EmitNeonCall(CGM.getIntrinsic(IID, Tys), Ops, name);
67676770
}
67686771

67696772
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
@@ -6779,9 +6782,7 @@ llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
67796782
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
67806783
Builder.getInt64(0));
67816784
}
6782-
llvm::Value *FPM =
6783-
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6784-
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6785+
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
67856786
}
67866787

67876788
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
@@ -6802,9 +6803,7 @@ Value *CodeGenFunction::EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
68026803
Tys[1] = llvm::FixedVectorType::get(Int8Ty, 8);
68036804
Ops[0] = Builder.CreateExtractVector(Tys[1], Ops[0], Builder.getInt64(0));
68046805
}
6805-
llvm::Value *FPM =
6806-
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6807-
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6806+
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
68086807
}
68096808

68106809
// Right-shift a vector by a constant.
@@ -14072,6 +14071,32 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1407214071
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
1407314072
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
1407414073
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
14074+
14075+
case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
14076+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
14077+
{llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
14078+
"vmlal");
14079+
case NEON::BI__builtin_neon_vmlaltq_f16_mf8_fpm:
14080+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalt,
14081+
{llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
14082+
"vmlal");
14083+
case NEON::BI__builtin_neon_vmlallbbq_f32_mf8_fpm:
14084+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbb,
14085+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14086+
"vmlall");
14087+
case NEON::BI__builtin_neon_vmlallbtq_f32_mf8_fpm:
14088+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbt,
14089+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14090+
"vmlall");
14091+
case NEON::BI__builtin_neon_vmlalltbq_f32_mf8_fpm:
14092+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltb,
14093+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14094+
"vmlall");
14095+
case NEON::BI__builtin_neon_vmlallttq_f32_mf8_fpm:
14096+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
14097+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14098+
"vmlall");
14099+
1407514100
case NEON::BI__builtin_neon_vamin_f16:
1407614101
case NEON::BI__builtin_neon_vaminq_f16:
1407714102
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4692,9 +4692,9 @@ class CodeGenFunction : public CodeGenTypeCache {
46924692
SmallVectorImpl<llvm::Value*> &O,
46934693
const char *name,
46944694
unsigned shift = 0, bool rightshift = false);
4695-
llvm::Value *EmitFP8NeonCall(llvm::Function *F,
4695+
llvm::Value *EmitFP8NeonCall(unsigned IID, ArrayRef<llvm::Type *> Tys,
46964696
SmallVectorImpl<llvm::Value *> &O,
4697-
llvm::Value *FPM, const char *name);
4697+
const CallExpr *E, const char *name);
46984698
llvm::Value *EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
46994699
llvm::Type *Ty1, bool Extract,
47004700
SmallVectorImpl<llvm::Value *> &Ops,
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s
3+
// RUN: %clang_cc1 -x c++ -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s -check-prefix CHECK-CXX
4+
5+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -S -o /dev/null %s
6+
7+
// REQUIRES: aarch64-registered-target
8+
9+
#include <arm_neon.h>
10+
11+
// CHECK-LABEL: define dso_local <8 x half> @test_vmlalb(
12+
// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
13+
// CHECK-NEXT: [[ENTRY:.*:]]
14+
// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
15+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
16+
// CHECK-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
17+
// CHECK-NEXT: ret <8 x half> [[VMLAL1_I]]
18+
//
19+
// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z11test_vmlalb13__Float16x8_t14__Mfloat8x16_tS0_m(
20+
// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
21+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
22+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
23+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
24+
// CHECK-CXX-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
25+
// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL1_I]]
26+
//
27+
float16x8_t test_vmlalb(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
28+
return vmlalbq_f16_mf8_fpm(vd, vn, vm, fpm);
29+
}
30+
31+
// CHECK-LABEL: define dso_local <8 x half> @test_vmlalt(
32+
// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
33+
// CHECK-NEXT: [[ENTRY:.*:]]
34+
// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
35+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
36+
// CHECK-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
37+
// CHECK-NEXT: ret <8 x half> [[VMLAL1_I]]
38+
//
39+
// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z11test_vmlalt13__Float16x8_t14__Mfloat8x16_tS0_m(
40+
// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
41+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
42+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
43+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
44+
// CHECK-CXX-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
45+
// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL1_I]]
46+
//
47+
float16x8_t test_vmlalt(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
48+
return vmlaltq_f16_mf8_fpm(vd, vn, vm, fpm);
49+
}
50+
51+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbb(
52+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
53+
// CHECK-NEXT: [[ENTRY:.*:]]
54+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
55+
// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
56+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
57+
//
58+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlallbb13__Float32x4_t14__Mfloat8x16_tS0_m(
59+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
60+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
61+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
62+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
63+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
64+
//
65+
float32x4_t test_vmlallbb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
66+
return vmlallbbq_f32_mf8_fpm(vd, vn, vm, fpm);
67+
}
68+
69+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbt(
70+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
71+
// CHECK-NEXT: [[ENTRY:.*:]]
72+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
73+
// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
74+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
75+
//
76+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlallbt13__Float32x4_t14__Mfloat8x16_tS0_m(
77+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
78+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
79+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
80+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
81+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
82+
//
83+
float32x4_t test_vmlallbt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
84+
return vmlallbtq_f32_mf8_fpm(vd, vn, vm, fpm);
85+
}
86+
87+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltb(
88+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
89+
// CHECK-NEXT: [[ENTRY:.*:]]
90+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
91+
// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
92+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
93+
//
94+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlalltb13__Float32x4_t14__Mfloat8x16_tS0_m(
95+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
96+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
97+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
98+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
99+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
100+
//
101+
float32x4_t test_vmlalltb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
102+
return vmlalltbq_f32_mf8_fpm(vd, vn, vm, fpm);
103+
}
104+
105+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltt(
106+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
107+
// CHECK-NEXT: [[ENTRY:.*:]]
108+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
109+
// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
110+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
111+
//
112+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlalltt13__Float32x4_t14__Mfloat8x16_tS0_m(
113+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
114+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
115+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
116+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
117+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
118+
//
119+
float32x4_t test_vmlalltt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
120+
return vmlallttq_f32_mf8_fpm(vd, vn, vm, fpm);
121+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %clang_cc1 -triple aarch64-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +faminmax -target-feature +fp8 -emit-llvm -verify %s -o /dev/null
2+
3+
// REQUIRES: aarch64-registered-target
4+
5+
#include <arm_neon.h>
6+
7+
void test_features(float16x8_t a, float32x4_t b, mfloat8x16_t u, fpm_t fpm) {
8+
9+
(void) vmlalbq_f16_mf8_fpm(a, u, u, fpm);
10+
// expected-error@-1 {{'vmlalbq_f16_mf8_fpm' requires target feature 'fp8fma'}}
11+
(void) vmlaltq_f16_mf8_fpm(a, u, u, fpm);
12+
// expected-error@-1 {{'vmlaltq_f16_mf8_fpm' requires target feature 'fp8fma'}}
13+
(void) vmlallbbq_f32_mf8_fpm(b, u, u, fpm);
14+
// expected-error@-1 {{'vmlallbbq_f32_mf8_fpm' requires target feature 'fp8fma'}}
15+
(void) vmlallbtq_f32_mf8_fpm(b, u, u, fpm);
16+
// expected-error@-1 {{'vmlallbtq_f32_mf8_fpm' requires target feature 'fp8fma'}}
17+
(void) vmlalltbq_f32_mf8_fpm(b, u, u, fpm);
18+
// expected-error@-1 {{'vmlalltbq_f32_mf8_fpm' requires target feature 'fp8fma'}}
19+
(void) vmlallttq_f32_mf8_fpm(b, u, u, fpm);
20+
// expected-error@-1 {{'vmlallttq_f32_mf8_fpm' requires target feature 'fp8fma'}}
21+
}
22+

llvm/include/llvm/IR/IntrinsicsAArch64.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,23 @@ def int_aarch64_st64bv0: Intrinsic<[llvm_i64_ty], !listconcat([llvm_ptr_ty], dat
10361036

10371037
def int_aarch64_neon_fp8_fdot4 : AdvSIMD_FP8_DOT_Intrinsic;
10381038
def int_aarch64_neon_fp8_fdot4_lane : AdvSIMD_FP8_DOT_LANE_Intrinsic;
1039+
1040+
1041+
// Fused multiply-add
1042+
class AdvSIMD_FP8_FMLA_Intrinsic
1043+
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
1044+
[LLVMMatchType<0>,
1045+
llvm_v16i8_ty,
1046+
llvm_v16i8_ty],
1047+
[IntrReadMem, IntrInaccessibleMemOnly]>;
1048+
1049+
def int_aarch64_neon_fp8_fmlalb : AdvSIMD_FP8_FMLA_Intrinsic;
1050+
def int_aarch64_neon_fp8_fmlalt : AdvSIMD_FP8_FMLA_Intrinsic;
1051+
1052+
def int_aarch64_neon_fp8_fmlallbb : AdvSIMD_FP8_FMLA_Intrinsic;
1053+
def int_aarch64_neon_fp8_fmlallbt : AdvSIMD_FP8_FMLA_Intrinsic;
1054+
def int_aarch64_neon_fp8_fmlalltb : AdvSIMD_FP8_FMLA_Intrinsic;
1055+
def int_aarch64_neon_fp8_fmlalltt : AdvSIMD_FP8_FMLA_Intrinsic;
10391056
}
10401057

10411058
def llvm_nxv1i1_ty : LLVMType<nxv1i1>;

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6519,14 +6519,15 @@ multiclass SIMDThreeSameVectorFML<bit U, bit b13, bits<3> size, string asm,
65196519
v4f32, v8f16, OpNode>;
65206520
}
65216521

6522-
multiclass SIMDThreeSameVectorMLA<bit Q, string asm>{
6522+
multiclass SIMDThreeSameVectorMLA<bit Q, string asm, SDPatternOperator op> {
6523+
65236524
def v8f16 : BaseSIMDThreeSameVectorDot<Q, 0b0, 0b11, 0b1111, asm, ".8h", ".16b",
6524-
V128, v8f16, v16i8, null_frag>;
6525+
V128, v8f16, v16i8, op>;
65256526
}
65266527

6527-
multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm>{
6528+
multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm, SDPatternOperator op> {
65286529
def v4f32 : BaseSIMDThreeSameVectorDot<Q, 0b0, sz, 0b1000, asm, ".4s", ".16b",
6529-
V128, v4f32, v16i8, null_frag>;
6530+
V128, v4f32, v16i8, op>;
65306531
}
65316532

65326533
// FP8 assembly/disassembly classes

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10353,20 +10353,20 @@ let Predicates = [HasNEON, HasFAMINMAX] in {
1035310353
defm FAMIN : SIMDThreeSameVectorFP<0b1, 0b1, 0b011, "famin", AArch64famin>;
1035410354
} // End let Predicates = [HasNEON, HasFAMINMAX]
1035510355

10356-
let Uses = [FPMR, FPCR], Predicates = [HasFP8FMA] in {
10356+
let Predicates = [HasFP8FMA], Uses = [FPMR, FPCR], mayLoad = 1 in {
1035710357
defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb">;
1035810358
defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt">;
1035910359
defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb">;
1036010360
defm FMLALLBTlane : SIMDThreeSameVectorMLALIndex<0b0, 0b01, "fmlallbt">;
1036110361
defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb">;
1036210362
defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt">;
1036310363

10364-
defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb">;
10365-
defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt">;
10366-
defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb">;
10367-
defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt">;
10368-
defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb">;
10369-
defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt">;
10364+
defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb", int_aarch64_neon_fp8_fmlalb>;
10365+
defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt", int_aarch64_neon_fp8_fmlalt>;
10366+
defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb", int_aarch64_neon_fp8_fmlallbb>;
10367+
defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt", int_aarch64_neon_fp8_fmlallbt>;
10368+
defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb", int_aarch64_neon_fp8_fmlalltb>;
10369+
defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt", int_aarch64_neon_fp8_fmlalltt>;
1037010370
} // End let Predicates = [HasFP8FMA]
1037110371

1037210372
let Predicates = [HasFP8DOT2] in {

0 commit comments

Comments
 (0)