Skip to content

Commit f75860f

Browse files
[AArch64] Implement NEON FP8 intrinsics for fused multiply-add (llvm#123615)
This patch adds the following intrinsics:

* Fused multiply-add non-indexed

  float16x8_t vmlalbq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
  float16x8_t vmlaltq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
  float32x4_t vmlallbbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
  float32x4_t vmlallbtq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
  float32x4_t vmlalltbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
  float32x4_t vmlallttq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)

* Floating-point multiply-add long to half-precision (vector, by element)

  float16x8_t vmlalbq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float16x8_t vmlalbq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float16x8_t vmlaltq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float16x8_t vmlaltq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)

* Floating-point multiply-add long-long to single-precision (vector, by element)

  float32x4_t vmlallbbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlallbbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlallbtq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlallbtq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlalltbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlalltbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlallttq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlallttq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
1 parent f949f87 commit f75860f

File tree

9 files changed

+717
-44
lines changed

9 files changed

+717
-44
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2161,6 +2161,30 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
21612161
def VDOTQ_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
21622162
}
21632163

2164+
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
2165+
def VMLALB_F16_F8 : VInst<"vmlalb_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
2166+
def VMLALT_F16_F8 : VInst<"vmlalt_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
2167+
2168+
def VMLALLBB_F32_F8 : VInst<"vmlallbb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2169+
def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2170+
def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2171+
def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2172+
2173+
def VMLALB_F16_F8_LANE : VInst<"vmlalb_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2174+
def VMLALB_F16_F8_LANEQ : VInst<"vmlalb_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2175+
def VMLALT_F16_F8_LANE : VInst<"vmlalt_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2176+
def VMLALT_F16_F8_LANEQ : VInst<"vmlalt_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2177+
2178+
def VMLALLBB_F32_F8_LANE : VInst<"vmlallbb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2179+
def VMLALLBB_F32_F8_LANEQ : VInst<"vmlallbb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2180+
def VMLALLBT_F32_F8_LANE : VInst<"vmlallbt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2181+
def VMLALLBT_F32_F8_LANEQ : VInst<"vmlallbt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2182+
def VMLALLTB_F32_F8_LANE : VInst<"vmlalltb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2183+
def VMLALLTB_F32_F8_LANEQ : VInst<"vmlalltb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2184+
def VMLALLTT_F32_F8_LANE : VInst<"vmlalltt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2185+
def VMLALLTT_F32_F8_LANEQ : VInst<"vmlalltt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2186+
}
2187+
21642188
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
21652189
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
21662190
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6759,29 +6759,45 @@ Value *CodeGenFunction::EmitNeonCall(Function *F, SmallVectorImpl<Value*> &Ops,
67596759
return Builder.CreateCall(F, Ops, name);
67606760
}
67616761

6762-
Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
6762+
Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
6763+
ArrayRef<llvm::Type *> Tys,
67636764
SmallVectorImpl<Value *> &Ops,
6764-
Value *FPM, const char *name) {
6765+
const CallExpr *E, const char *name) {
6766+
llvm::Value *FPM =
6767+
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
67656768
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), FPM);
6766-
return EmitNeonCall(F, Ops, name);
6769+
return EmitNeonCall(CGM.getIntrinsic(IID, Tys), Ops, name);
67676770
}
67686771

67696772
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
6770-
unsigned IID, bool ExtendLane, llvm::Type *RetTy,
6773+
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
67716774
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
67726775

67736776
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
67746777
RetTy->getPrimitiveSizeInBits();
67756778
llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount),
67766779
Ops[1]->getType()};
6777-
if (ExtendLane) {
6780+
if (ExtendLaneArg) {
67786781
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
67796782
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
67806783
Builder.getInt64(0));
67816784
}
6782-
llvm::Value *FPM =
6783-
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6784-
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6785+
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
6786+
}
6787+
6788+
llvm::Value *CodeGenFunction::EmitFP8NeonFMLACall(
6789+
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
6790+
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
6791+
6792+
if (ExtendLaneArg) {
6793+
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
6794+
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
6795+
Builder.getInt64(0));
6796+
}
6797+
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
6798+
RetTy->getPrimitiveSizeInBits();
6799+
return EmitFP8NeonCall(IID, {llvm::FixedVectorType::get(RetTy, ElemCount)},
6800+
Ops, E, name);
67856801
}
67866802

67876803
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
@@ -6802,9 +6818,7 @@ Value *CodeGenFunction::EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
68026818
Tys[1] = llvm::FixedVectorType::get(Int8Ty, 8);
68036819
Ops[0] = Builder.CreateExtractVector(Tys[1], Ops[0], Builder.getInt64(0));
68046820
}
6805-
llvm::Value *FPM =
6806-
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6807-
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6821+
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
68086822
}
68096823

68106824
// Right-shift a vector by a constant.
@@ -12779,7 +12793,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1277912793

1278012794
unsigned Int;
1278112795
bool ExtractLow = false;
12782-
bool ExtendLane = false;
12796+
bool ExtendLaneArg = false;
1278312797
switch (BuiltinID) {
1278412798
default: return nullptr;
1278512799
case NEON::BI__builtin_neon_vbsl_v:
@@ -14054,24 +14068,85 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1405414068
Ops, E, "fdot2");
1405514069
case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm:
1405614070
case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm:
14057-
ExtendLane = true;
14071+
ExtendLaneArg = true;
1405814072
LLVM_FALLTHROUGH;
1405914073
case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm:
1406014074
case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm:
1406114075
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane,
14062-
ExtendLane, HalfTy, Ops, E, "fdot2_lane");
14076+
ExtendLaneArg, HalfTy, Ops, E, "fdot2_lane");
1406314077
case NEON::BI__builtin_neon_vdot_f32_mf8_fpm:
1406414078
case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm:
1406514079
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false,
1406614080
FloatTy, Ops, E, "fdot4");
1406714081
case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm:
1406814082
case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm:
14069-
ExtendLane = true;
14083+
ExtendLaneArg = true;
1407014084
LLVM_FALLTHROUGH;
1407114085
case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm:
1407214086
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
1407314087
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
14074-
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
14088+
ExtendLaneArg, FloatTy, Ops, E, "fdot4_lane");
14089+
14090+
case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
14091+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
14092+
{llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
14093+
"vmlal");
14094+
case NEON::BI__builtin_neon_vmlaltq_f16_mf8_fpm:
14095+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalt,
14096+
{llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
14097+
"vmlal");
14098+
case NEON::BI__builtin_neon_vmlallbbq_f32_mf8_fpm:
14099+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbb,
14100+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14101+
"vmlall");
14102+
case NEON::BI__builtin_neon_vmlallbtq_f32_mf8_fpm:
14103+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbt,
14104+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14105+
"vmlall");
14106+
case NEON::BI__builtin_neon_vmlalltbq_f32_mf8_fpm:
14107+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltb,
14108+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14109+
"vmlall");
14110+
case NEON::BI__builtin_neon_vmlallttq_f32_mf8_fpm:
14111+
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
14112+
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
14113+
"vmlall");
14114+
case NEON::BI__builtin_neon_vmlalbq_lane_f16_mf8_fpm:
14115+
ExtendLaneArg = true;
14116+
LLVM_FALLTHROUGH;
14117+
case NEON::BI__builtin_neon_vmlalbq_laneq_f16_mf8_fpm:
14118+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalb_lane,
14119+
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
14120+
case NEON::BI__builtin_neon_vmlaltq_lane_f16_mf8_fpm:
14121+
ExtendLaneArg = true;
14122+
LLVM_FALLTHROUGH;
14123+
case NEON::BI__builtin_neon_vmlaltq_laneq_f16_mf8_fpm:
14124+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalt_lane,
14125+
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
14126+
case NEON::BI__builtin_neon_vmlallbbq_lane_f32_mf8_fpm:
14127+
ExtendLaneArg = true;
14128+
LLVM_FALLTHROUGH;
14129+
case NEON::BI__builtin_neon_vmlallbbq_laneq_f32_mf8_fpm:
14130+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbb_lane,
14131+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14132+
case NEON::BI__builtin_neon_vmlallbtq_lane_f32_mf8_fpm:
14133+
ExtendLaneArg = true;
14134+
LLVM_FALLTHROUGH;
14135+
case NEON::BI__builtin_neon_vmlallbtq_laneq_f32_mf8_fpm:
14136+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbt_lane,
14137+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14138+
case NEON::BI__builtin_neon_vmlalltbq_lane_f32_mf8_fpm:
14139+
ExtendLaneArg = true;
14140+
LLVM_FALLTHROUGH;
14141+
case NEON::BI__builtin_neon_vmlalltbq_laneq_f32_mf8_fpm:
14142+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltb_lane,
14143+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14144+
case NEON::BI__builtin_neon_vmlallttq_lane_f32_mf8_fpm:
14145+
ExtendLaneArg = true;
14146+
LLVM_FALLTHROUGH;
14147+
case NEON::BI__builtin_neon_vmlallttq_laneq_f32_mf8_fpm:
14148+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltt_lane,
14149+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
1407514150
case NEON::BI__builtin_neon_vamin_f16:
1407614151
case NEON::BI__builtin_neon_vaminq_f16:
1407714152
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4692,14 +4692,18 @@ class CodeGenFunction : public CodeGenTypeCache {
46924692
SmallVectorImpl<llvm::Value*> &O,
46934693
const char *name,
46944694
unsigned shift = 0, bool rightshift = false);
4695-
llvm::Value *EmitFP8NeonCall(llvm::Function *F,
4695+
llvm::Value *EmitFP8NeonCall(unsigned IID, ArrayRef<llvm::Type *> Tys,
46964696
SmallVectorImpl<llvm::Value *> &O,
4697-
llvm::Value *FPM, const char *name);
4697+
const CallExpr *E, const char *name);
46984698
llvm::Value *EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
46994699
llvm::Type *Ty1, bool Extract,
47004700
SmallVectorImpl<llvm::Value *> &Ops,
47014701
const CallExpr *E, const char *name);
4702-
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane,
4702+
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLaneArg,
4703+
llvm::Type *RetTy,
4704+
SmallVectorImpl<llvm::Value *> &Ops,
4705+
const CallExpr *E, const char *name);
4706+
llvm::Value *EmitFP8NeonFMLACall(unsigned IID, bool ExtendLaneArg,
47034707
llvm::Type *RetTy,
47044708
SmallVectorImpl<llvm::Value *> &Ops,
47054709
const CallExpr *E, const char *name);

0 commit comments

Comments
 (0)