Skip to content

Commit 660fdce

Browse files
[AArch64] Implement NEON FP8 fused multiply-add intrinsics (non-indexed)
This patch adds the following intrinsics:

    float16x8_t vmlalbq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float16x8_t vmlaltq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t)

    float32x4_t vmlallbbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlallbtq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlalltbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
    float32x4_t vmlallttq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
1 parent b755f98 commit 660fdce

File tree

9 files changed

+271
-22
lines changed

9 files changed

+271
-22
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2170,6 +2170,17 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
21702170
def VDOTQ_LANEQ_F32_MF8 : VInst<"vdotq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
21712171
}
21722172

2173+
2174+
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
2175+
def VMLALB_F16_F8 : VInst<"vmlalbq_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
2176+
def VMLALT_F16_F8 : VInst<"vmlaltq_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
2177+
2178+
def VMLALLBB_F32_F8 : VInst<"vmlallbbq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2179+
def VMLALLBT_F32_F8 : VInst<"vmlallbtq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2180+
def VMLALLTB_F32_F8 : VInst<"vmlalltbq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2181+
def VMLALLTT_F32_F8 : VInst<"vmlallttq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2182+
}
2183+
21732184
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
21742185
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
21752186
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6869,11 +6869,14 @@ Value *CodeGenFunction::EmitNeonCall(Function *F, SmallVectorImpl<Value*> &Ops,
68696869
return Builder.CreateCall(F, Ops, name);
68706870
}
68716871

6872-
Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
6872+
Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
6873+
ArrayRef<llvm::Type *> Tys,
68736874
SmallVectorImpl<Value *> &Ops,
6874-
Value *FPM, const char *name) {
6875+
const CallExpr *E, const char *name) {
6876+
llvm::Value *FPM =
6877+
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
68756878
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), FPM);
6876-
return EmitNeonCall(F, Ops, name);
6879+
return EmitNeonCall(CGM.getIntrinsic(IID, Tys), Ops, name);
68776880
}
68786881

68796882
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
@@ -6889,9 +6892,7 @@ llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
68896892
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
68906893
Builder.getInt64(0));
68916894
}
6892-
llvm::Value *FPM =
6893-
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6894-
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6895+
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
68956896
}
68966897

68976898
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
@@ -6912,9 +6913,7 @@ Value *CodeGenFunction::EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
69126913
Tys[1] = llvm::FixedVectorType::get(Int8Ty, 8);
69136914
Ops[0] = Builder.CreateExtractVector(Tys[1], Ops[0], Builder.getInt64(0));
69146915
}
6915-
llvm::Value *FPM =
6916-
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6917-
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6916+
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
69186917
}
69196918

69206919
// Right-shift a vector by a constant.
@@ -14134,6 +14133,32 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1413414133
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
1413514134
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
1413614135
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
14136+
14137+
case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
14138+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
14139+
{llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
14140+
"vmlal");
14141+
case NEON::BI__builtin_neon_vmlaltq_f16_mf8_fpm:
14142+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalt,
14143+
{llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
14144+
"vmlal");
14145+
case NEON::BI__builtin_neon_vmlallbbq_f32_mf8_fpm:
14146+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbb,
14147+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14148+
"vmlall");
14149+
case NEON::BI__builtin_neon_vmlallbtq_f32_mf8_fpm:
14150+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbt,
14151+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14152+
"vmlall");
14153+
case NEON::BI__builtin_neon_vmlalltbq_f32_mf8_fpm:
14154+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltb,
14155+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14156+
"vmlall");
14157+
case NEON::BI__builtin_neon_vmlallttq_f32_mf8_fpm:
14158+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
14159+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14160+
"vmlall");
14161+
1413714162
case NEON::BI__builtin_neon_vamin_f16:
1413814163
case NEON::BI__builtin_neon_vaminq_f16:
1413914164
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4648,9 +4648,9 @@ class CodeGenFunction : public CodeGenTypeCache {
46484648
SmallVectorImpl<llvm::Value*> &O,
46494649
const char *name,
46504650
unsigned shift = 0, bool rightshift = false);
4651-
llvm::Value *EmitFP8NeonCall(llvm::Function *F,
4651+
llvm::Value *EmitFP8NeonCall(unsigned IID, ArrayRef<llvm::Type *> Tys,
46524652
SmallVectorImpl<llvm::Value *> &O,
4653-
llvm::Value *FPM, const char *name);
4653+
const CallExpr *E, const char *name);
46544654
llvm::Value *EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
46554655
llvm::Type *Ty1, bool Extract,
46564656
SmallVectorImpl<llvm::Value *> &Ops,
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s
3+
// RUN: %clang_cc1 -x c++ -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s -check-prefix CHECK-CXX
4+
5+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -O3 -Werror -Wall -S -o /dev/null %s
6+
7+
// REQUIRES: aarch64-registered-target
8+
9+
#include <arm_neon.h>
10+
11+
// CHECK-LABEL: define dso_local <8 x half> @test_vmlalb(
12+
// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
13+
// CHECK-NEXT: [[ENTRY:.*:]]
14+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
15+
// CHECK-NEXT: [[VMLAL1_I:%.*]] = tail call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
16+
// CHECK-NEXT: ret <8 x half> [[VMLAL1_I]]
17+
//
18+
// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z11test_vmlalb13__Float16x8_t14__Mfloat8x16_tS0_m(
19+
// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
20+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
21+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
22+
// CHECK-CXX-NEXT: [[VMLAL1_I:%.*]] = tail call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
23+
// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL1_I]]
24+
//
25+
float16x8_t test_vmlalb(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
26+
return vmlalbq_f16_mf8_fpm(vd, vn, vm, fpm);
27+
}
28+
29+
// CHECK-LABEL: define dso_local <8 x half> @test_vmlalt(
30+
// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
31+
// CHECK-NEXT: [[ENTRY:.*:]]
32+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
33+
// CHECK-NEXT: [[VMLAL1_I:%.*]] = tail call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
34+
// CHECK-NEXT: ret <8 x half> [[VMLAL1_I]]
35+
//
36+
// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z11test_vmlalt13__Float16x8_t14__Mfloat8x16_tS0_m(
37+
// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
38+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
39+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
40+
// CHECK-CXX-NEXT: [[VMLAL1_I:%.*]] = tail call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
41+
// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL1_I]]
42+
//
43+
float16x8_t test_vmlalt(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
44+
return vmlaltq_f16_mf8_fpm(vd, vn, vm, fpm);
45+
}
46+
47+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbb(
48+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
49+
// CHECK-NEXT: [[ENTRY:.*:]]
50+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
51+
// CHECK-NEXT: [[VMLALL_I:%.*]] = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
52+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
53+
//
54+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlallbb13__Float32x4_t14__Mfloat8x16_tS0_m(
55+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
56+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
57+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
58+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
59+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
60+
//
61+
float32x4_t test_vmlallbb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
62+
return vmlallbbq_f32_mf8_fpm(vd, vn, vm, fpm);
63+
}
64+
65+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbt(
66+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
67+
// CHECK-NEXT: [[ENTRY:.*:]]
68+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
69+
// CHECK-NEXT: [[VMLALL_I:%.*]] = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
70+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
71+
//
72+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlallbt13__Float32x4_t14__Mfloat8x16_tS0_m(
73+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
74+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
75+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
76+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
77+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
78+
//
79+
float32x4_t test_vmlallbt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
80+
return vmlallbtq_f32_mf8_fpm(vd, vn, vm, fpm);
81+
}
82+
83+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltb(
84+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
85+
// CHECK-NEXT: [[ENTRY:.*:]]
86+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
87+
// CHECK-NEXT: [[VMLALL_I:%.*]] = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
88+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
89+
//
90+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlalltb13__Float32x4_t14__Mfloat8x16_tS0_m(
91+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
92+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
93+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
94+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
95+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
96+
//
97+
float32x4_t test_vmlalltb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
98+
return vmlalltbq_f32_mf8_fpm(vd, vn, vm, fpm);
99+
}
100+
101+
// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltt(
102+
// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
103+
// CHECK-NEXT: [[ENTRY:.*:]]
104+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
105+
// CHECK-NEXT: [[VMLALL_I:%.*]] = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
106+
// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
107+
//
108+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlalltt13__Float32x4_t14__Mfloat8x16_tS0_m(
109+
// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
110+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
111+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
112+
// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
113+
// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
114+
//
115+
float32x4_t test_vmlalltt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
116+
return vmlallttq_f32_mf8_fpm(vd, vn, vm, fpm);
117+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %clang_cc1 -triple aarch64-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +faminmax -target-feature +fp8 -emit-llvm -verify %s -o /dev/null
2+
3+
// REQUIRES: aarch64-registered-target
4+
5+
#include <arm_neon.h>
6+
7+
void test_features(float16x8_t a, float32x4_t b, mfloat8x16_t u, fpm_t fpm) {
8+
9+
(void) vmlalbq_f16_mf8_fpm(a, u, u, fpm);
10+
// expected-error@-1 {{'vmlalbq_f16_mf8_fpm' requires target feature 'fp8fma'}}
11+
(void) vmlaltq_f16_mf8_fpm(a, u, u, fpm);
12+
// expected-error@-1 {{'vmlaltq_f16_mf8_fpm' requires target feature 'fp8fma'}}
13+
(void) vmlallbbq_f32_mf8_fpm(b, u, u, fpm);
14+
// expected-error@-1 {{'vmlallbbq_f32_mf8_fpm' requires target feature 'fp8fma'}}
15+
(void) vmlallbtq_f32_mf8_fpm(b, u, u, fpm);
16+
// expected-error@-1 {{'vmlallbtq_f32_mf8_fpm' requires target feature 'fp8fma'}}
17+
(void) vmlalltbq_f32_mf8_fpm(b, u, u, fpm);
18+
// expected-error@-1 {{'vmlalltbq_f32_mf8_fpm' requires target feature 'fp8fma'}}
19+
(void) vmlallttq_f32_mf8_fpm(b, u, u, fpm);
20+
// expected-error@-1 {{'vmlallttq_f32_mf8_fpm' requires target feature 'fp8fma'}}
21+
}
22+

llvm/include/llvm/IR/IntrinsicsAArch64.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,23 @@ def int_aarch64_st64bv0: Intrinsic<[llvm_i64_ty], !listconcat([llvm_ptr_ty], dat
10471047

10481048
def int_aarch64_neon_fp8_fdot4 : AdvSIMD_FP8_DOT_Intrinsic;
10491049
def int_aarch64_neon_fp8_fdot4_lane : AdvSIMD_FP8_DOT_LANE_Intrinsic;
1050+
1051+
1052+
// Fused multiply-add
1053+
class AdvSIMD_FP8_FMLA_Intrinsic
1054+
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
1055+
[LLVMMatchType<0>,
1056+
llvm_v16i8_ty,
1057+
llvm_v16i8_ty],
1058+
[IntrReadMem, IntrInaccessibleMemOnly]>;
1059+
1060+
def int_aarch64_neon_fp8_fmlalb : AdvSIMD_FP8_FMLA_Intrinsic;
1061+
def int_aarch64_neon_fp8_fmlalt : AdvSIMD_FP8_FMLA_Intrinsic;
1062+
1063+
def int_aarch64_neon_fp8_fmlallbb : AdvSIMD_FP8_FMLA_Intrinsic;
1064+
def int_aarch64_neon_fp8_fmlallbt : AdvSIMD_FP8_FMLA_Intrinsic;
1065+
def int_aarch64_neon_fp8_fmlalltb : AdvSIMD_FP8_FMLA_Intrinsic;
1066+
def int_aarch64_neon_fp8_fmlalltt : AdvSIMD_FP8_FMLA_Intrinsic;
10501067
}
10511068

10521069
def llvm_nxv1i1_ty : LLVMType<nxv1i1>;

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6511,14 +6511,15 @@ multiclass SIMDThreeSameVectorFML<bit U, bit b13, bits<3> size, string asm,
65116511
v4f32, v8f16, OpNode>;
65126512
}
65136513

6514-
multiclass SIMDThreeSameVectorMLA<bit Q, string asm>{
6514+
multiclass SIMDThreeSameVectorMLA<bit Q, string asm, SDPatternOperator op> {
6515+
65156516
def v8f16 : BaseSIMDThreeSameVectorDot<Q, 0b0, 0b11, 0b1111, asm, ".8h", ".16b",
6516-
V128, v8f16, v16i8, null_frag>;
6517+
V128, v8f16, v16i8, op>;
65176518
}
65186519

6519-
multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm>{
6520+
multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm, SDPatternOperator op> {
65206521
def v4f32 : BaseSIMDThreeSameVectorDot<Q, 0b0, sz, 0b1000, asm, ".4s", ".16b",
6521-
V128, v4f32, v16i8, null_frag>;
6522+
V128, v4f32, v16i8, op>;
65226523
}
65236524

65246525
// FP8 assembly/disassembly classes

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10358,20 +10358,20 @@ let Predicates = [HasNEON, HasFAMINMAX] in {
1035810358
defm FAMIN : SIMDThreeSameVectorFP<0b1, 0b1, 0b011, "famin", AArch64famin>;
1035910359
} // End let Predicates = [HasNEON, HasFAMINMAX]
1036010360

10361-
let Uses = [FPMR, FPCR], Predicates = [HasFP8FMA] in {
10361+
let Predicates = [HasFP8FMA], Uses = [FPMR, FPCR], mayLoad = 1 in {
1036210362
defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb">;
1036310363
defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt">;
1036410364
defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb">;
1036510365
defm FMLALLBTlane : SIMDThreeSameVectorMLALIndex<0b0, 0b01, "fmlallbt">;
1036610366
defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb">;
1036710367
defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt">;
1036810368

10369-
defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb">;
10370-
defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt">;
10371-
defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb">;
10372-
defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt">;
10373-
defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb">;
10374-
defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt">;
10369+
defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb", int_aarch64_neon_fp8_fmlalb>;
10370+
defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt", int_aarch64_neon_fp8_fmlalt>;
10371+
defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb", int_aarch64_neon_fp8_fmlallbb>;
10372+
defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt", int_aarch64_neon_fp8_fmlallbt>;
10373+
defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb", int_aarch64_neon_fp8_fmlalltb>;
10374+
defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt", int_aarch64_neon_fp8_fmlalltt>;
1037510375
} // End let Predicates = [HasFP8FMA]
1037610376

1037710377
let Predicates = [HasFP8DOT2] in {

0 commit comments

Comments
 (0)