Skip to content

Commit 2ceaf4b

Browse files
[AArch64] Implement NEON FP8 fused multiply-add intrinsics (non-indexed)
This patch adds the following intrinsics:

    float16x8_t vmlalbq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float16x8_t vmlaltq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlallbbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlallbtq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlalltbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlallttq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)

[fixup] Update intrinsics definitions
[fixup] Remove some opt passes from RUN lines
1 parent 1c6a50e commit 2ceaf4b

File tree

9 files changed

+274
-22
lines changed

9 files changed

+274
-22
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2167,6 +2167,16 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
21672167
def VDOTQ_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
21682168
}
21692169

2170+
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
2171+
def VMLALB_F16_F8 : VInst<"vmlalb_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
2172+
def VMLALT_F16_F8 : VInst<"vmlalt_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
2173+
2174+
def VMLALLBB_F32_F8 : VInst<"vmlallbb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2175+
def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2176+
def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2177+
def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2178+
}
2179+
21702180
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
21712181
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
21722182
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6899,11 +6899,14 @@ Value *CodeGenFunction::EmitNeonCall(Function *F, SmallVectorImpl<Value*> &Ops,
68996899
return Builder.CreateCall(F, Ops, name);
69006900
}
69016901

6902-
Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
6902+
Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
6903+
ArrayRef<llvm::Type *> Tys,
69036904
SmallVectorImpl<Value *> &Ops,
6904-
Value *FPM, const char *name) {
6905+
const CallExpr *E, const char *name) {
6906+
llvm::Value *FPM =
6907+
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
69056908
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), FPM);
6906-
return EmitNeonCall(F, Ops, name);
6909+
return EmitNeonCall(CGM.getIntrinsic(IID, Tys), Ops, name);
69076910
}
69086911

69096912
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
@@ -6919,9 +6922,7 @@ llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
69196922
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
69206923
Builder.getInt64(0));
69216924
}
6922-
llvm::Value *FPM =
6923-
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6924-
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6925+
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
69256926
}
69266927

69276928
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
@@ -6942,9 +6943,7 @@ Value *CodeGenFunction::EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
69426943
Tys[1] = llvm::FixedVectorType::get(Int8Ty, 8);
69436944
Ops[0] = Builder.CreateExtractVector(Tys[1], Ops[0], Builder.getInt64(0));
69446945
}
6945-
llvm::Value *FPM =
6946-
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6947-
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6946+
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
69486947
}
69496948

69506949
// Right-shift a vector by a constant.
@@ -14177,6 +14176,32 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1417714176
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
1417814177
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
1417914178
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
14179+
14180+
case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
14181+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
14182+
{llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
14183+
"vmlal");
14184+
case NEON::BI__builtin_neon_vmlaltq_f16_mf8_fpm:
14185+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalt,
14186+
{llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
14187+
"vmlal");
14188+
case NEON::BI__builtin_neon_vmlallbbq_f32_mf8_fpm:
14189+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbb,
14190+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14191+
"vmlall");
14192+
case NEON::BI__builtin_neon_vmlallbtq_f32_mf8_fpm:
14193+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbt,
14194+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14195+
"vmlall");
14196+
case NEON::BI__builtin_neon_vmlalltbq_f32_mf8_fpm:
14197+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltb,
14198+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14199+
"vmlall");
14200+
case NEON::BI__builtin_neon_vmlallttq_f32_mf8_fpm:
14201+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
14202+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14203+
"vmlall");
14204+
1418014205
case NEON::BI__builtin_neon_vamin_f16:
1418114206
case NEON::BI__builtin_neon_vaminq_f16:
1418214207
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4691,9 +4691,9 @@ class CodeGenFunction : public CodeGenTypeCache {
46914691
SmallVectorImpl<llvm::Value*> &O,
46924692
const char *name,
46934693
unsigned shift = 0, bool rightshift = false);
4694-
llvm::Value *EmitFP8NeonCall(llvm::Function *F,
4694+
llvm::Value *EmitFP8NeonCall(unsigned IID, ArrayRef<llvm::Type *> Tys,
46954695
SmallVectorImpl<llvm::Value *> &O,
4696-
llvm::Value *FPM, const char *name);
4696+
const CallExpr *E, const char *name);
46974697
llvm::Value *EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
46984698
llvm::Type *Ty1, bool Extract,
46994699
SmallVectorImpl<llvm::Value *> &Ops,
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s
3+
// RUN: %clang_cc1 -x c++ -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s -check-prefix CHECK-CXX
4+
5+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -S -o /dev/null %s
6+
7+
// REQUIRES: aarch64-registered-target
8+
9+
#include <arm_neon.h>
10+
11+
// CHECK-LABEL: define dso_local <8 x half> @test_vmlalb(
12+
// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
13+
// CHECK-NEXT: [[ENTRY:.*:]]
14+
// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
15+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
16+
// CHECK-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
17+
// CHECK-NEXT: ret <8 x half> [[VMLAL1_I]]
18+
//
19+
// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z11test_vmlalb13__Float16x8_t14__Mfloat8x16_tS0_m(
20+
// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
21+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
22+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
23+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
24+
// CHECK-CXX-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
25+
// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL1_I]]
26+
//
27+
float16x8_t test_vmlalb(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
28+
return vmlalbq_f16_mf8_fpm(vd, vn, vm, fpm);
29+
}
30+
31+
// CHECK-LABEL: define dso_local <8 x half> @test_vmlalt(
32+
// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
33+
// CHECK-NEXT: [[ENTRY:.*:]]
34+
// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
35+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
36+
// CHECK-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
37+
// CHECK-NEXT: ret <8 x half> [[VMLAL1_I]]
38+
//
39+
// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z11test_vmlalt13__Float16x8_t14__Mfloat8x16_tS0_m(
40+
// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
41+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
42+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
43+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
44+
// CHECK-CXX-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
45+
// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL1_I]]
46+
//
47+
float16x8_t test_vmlalt(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
48+
return vmlaltq_f16_mf8_fpm(vd, vn, vm, fpm);
49+
}
50+
51+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbb(
52+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
53+
// CHECK-NEXT: [[ENTRY:.*:]]
54+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
55+
// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
56+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
57+
//
58+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlallbb13__Float32x4_t14__Mfloat8x16_tS0_m(
59+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
60+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
61+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
62+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
63+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
64+
//
65+
float32x4_t test_vmlallbb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
66+
return vmlallbbq_f32_mf8_fpm(vd, vn, vm, fpm);
67+
}
68+
69+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbt(
70+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
71+
// CHECK-NEXT: [[ENTRY:.*:]]
72+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
73+
// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
74+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
75+
//
76+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlallbt13__Float32x4_t14__Mfloat8x16_tS0_m(
77+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
78+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
79+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
80+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
81+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
82+
//
83+
float32x4_t test_vmlallbt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
84+
return vmlallbtq_f32_mf8_fpm(vd, vn, vm, fpm);
85+
}
86+
87+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltb(
88+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
89+
// CHECK-NEXT: [[ENTRY:.*:]]
90+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
91+
// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
92+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
93+
//
94+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlalltb13__Float32x4_t14__Mfloat8x16_tS0_m(
95+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
96+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
97+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
98+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
99+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
100+
//
101+
float32x4_t test_vmlalltb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
102+
return vmlalltbq_f32_mf8_fpm(vd, vn, vm, fpm);
103+
}
104+
105+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltt(
106+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
107+
// CHECK-NEXT: [[ENTRY:.*:]]
108+
// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
109+
// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
110+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
111+
//
112+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlalltt13__Float32x4_t14__Mfloat8x16_tS0_m(
113+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
114+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
115+
// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
116+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
117+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
118+
//
119+
float32x4_t test_vmlalltt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
120+
return vmlallttq_f32_mf8_fpm(vd, vn, vm, fpm);
121+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %clang_cc1 -triple aarch64-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +faminmax -target-feature +fp8 -emit-llvm -verify %s -o /dev/null
2+
3+
// REQUIRES: aarch64-registered-target
4+
5+
#include <arm_neon.h>
6+
7+
void test_features(float16x8_t a, float32x4_t b, mfloat8x16_t u, fpm_t fpm) {
8+
9+
(void) vmlalbq_f16_mf8_fpm(a, u, u, fpm);
10+
// expected-error@-1 {{'vmlalbq_f16_mf8_fpm' requires target feature 'fp8fma'}}
11+
(void) vmlaltq_f16_mf8_fpm(a, u, u, fpm);
12+
// expected-error@-1 {{'vmlaltq_f16_mf8_fpm' requires target feature 'fp8fma'}}
13+
(void) vmlallbbq_f32_mf8_fpm(b, u, u, fpm);
14+
// expected-error@-1 {{'vmlallbbq_f32_mf8_fpm' requires target feature 'fp8fma'}}
15+
(void) vmlallbtq_f32_mf8_fpm(b, u, u, fpm);
16+
// expected-error@-1 {{'vmlallbtq_f32_mf8_fpm' requires target feature 'fp8fma'}}
17+
(void) vmlalltbq_f32_mf8_fpm(b, u, u, fpm);
18+
// expected-error@-1 {{'vmlalltbq_f32_mf8_fpm' requires target feature 'fp8fma'}}
19+
(void) vmlallttq_f32_mf8_fpm(b, u, u, fpm);
20+
// expected-error@-1 {{'vmlallttq_f32_mf8_fpm' requires target feature 'fp8fma'}}
21+
}
22+

llvm/include/llvm/IR/IntrinsicsAArch64.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,23 @@ def int_aarch64_st64bv0: Intrinsic<[llvm_i64_ty], !listconcat([llvm_ptr_ty], dat
10471047

10481048
def int_aarch64_neon_fp8_fdot4 : AdvSIMD_FP8_DOT_Intrinsic;
10491049
def int_aarch64_neon_fp8_fdot4_lane : AdvSIMD_FP8_DOT_LANE_Intrinsic;
1050+
1051+
1052+
// Fused multiply-add
1053+
class AdvSIMD_FP8_FMLA_Intrinsic
1054+
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
1055+
[LLVMMatchType<0>,
1056+
llvm_v16i8_ty,
1057+
llvm_v16i8_ty],
1058+
[IntrReadMem, IntrInaccessibleMemOnly]>;
1059+
1060+
def int_aarch64_neon_fp8_fmlalb : AdvSIMD_FP8_FMLA_Intrinsic;
1061+
def int_aarch64_neon_fp8_fmlalt : AdvSIMD_FP8_FMLA_Intrinsic;
1062+
1063+
def int_aarch64_neon_fp8_fmlallbb : AdvSIMD_FP8_FMLA_Intrinsic;
1064+
def int_aarch64_neon_fp8_fmlallbt : AdvSIMD_FP8_FMLA_Intrinsic;
1065+
def int_aarch64_neon_fp8_fmlalltb : AdvSIMD_FP8_FMLA_Intrinsic;
1066+
def int_aarch64_neon_fp8_fmlalltt : AdvSIMD_FP8_FMLA_Intrinsic;
10501067
}
10511068

10521069
def llvm_nxv1i1_ty : LLVMType<nxv1i1>;

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6519,14 +6519,15 @@ multiclass SIMDThreeSameVectorFML<bit U, bit b13, bits<3> size, string asm,
65196519
v4f32, v8f16, OpNode>;
65206520
}
65216521

6522-
multiclass SIMDThreeSameVectorMLA<bit Q, string asm>{
6522+
multiclass SIMDThreeSameVectorMLA<bit Q, string asm, SDPatternOperator op> {
6523+
65236524
def v8f16 : BaseSIMDThreeSameVectorDot<Q, 0b0, 0b11, 0b1111, asm, ".8h", ".16b",
6524-
V128, v8f16, v16i8, null_frag>;
6525+
V128, v8f16, v16i8, op>;
65256526
}
65266527

6527-
multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm>{
6528+
multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm, SDPatternOperator op> {
65286529
def v4f32 : BaseSIMDThreeSameVectorDot<Q, 0b0, sz, 0b1000, asm, ".4s", ".16b",
6529-
V128, v4f32, v16i8, null_frag>;
6530+
V128, v4f32, v16i8, op>;
65306531
}
65316532

65326533
// FP8 assembly/disassembly classes

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10345,20 +10345,20 @@ let Predicates = [HasNEON, HasFAMINMAX] in {
1034510345
defm FAMIN : SIMDThreeSameVectorFP<0b1, 0b1, 0b011, "famin", AArch64famin>;
1034610346
} // End let Predicates = [HasNEON, HasFAMINMAX]
1034710347

10348-
let Uses = [FPMR, FPCR], Predicates = [HasFP8FMA] in {
10348+
let Predicates = [HasFP8FMA], Uses = [FPMR, FPCR], mayLoad = 1 in {
1034910349
defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb">;
1035010350
defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt">;
1035110351
defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb">;
1035210352
defm FMLALLBTlane : SIMDThreeSameVectorMLALIndex<0b0, 0b01, "fmlallbt">;
1035310353
defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb">;
1035410354
defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt">;
1035510355

10356-
defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb">;
10357-
defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt">;
10358-
defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb">;
10359-
defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt">;
10360-
defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb">;
10361-
defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt">;
10356+
defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb", int_aarch64_neon_fp8_fmlalb>;
10357+
defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt", int_aarch64_neon_fp8_fmlalt>;
10358+
defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb", int_aarch64_neon_fp8_fmlallbb>;
10359+
defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt", int_aarch64_neon_fp8_fmlallbt>;
10360+
defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb", int_aarch64_neon_fp8_fmlalltb>;
10361+
defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt", int_aarch64_neon_fp8_fmlalltt>;
1036210362
} // End let Predicates = [HasFP8FMA]
1036310363

1036410364
let Predicates = [HasFP8DOT2] in {

0 commit comments

Comments
 (0)