Skip to content

Commit 4a79123

Browse files
[AArch64] Implement NEON FP8 intrinsics for fused multiply-add (indexed)
This patch adds the following intrinsics: * Floating-point multiply-add long to half-precision (vector, by element) float16x8_t vmlalbq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlalbq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlaltq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlaltq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) * Floating-point multiply-add long-long to single-precision (vector, by element) float32x4_t vmlallbbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbtq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbtq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlalltbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlalltbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallttq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallttq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) [fixup] Update intrinsics definitions [fixup] Regenerate tests
1 parent e677f83 commit 4a79123

File tree

9 files changed

+445
-24
lines changed

9 files changed

+445
-24
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2169,6 +2169,20 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
21692169
def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
21702170
def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
21712171
def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2172+
2173+
def VMLALB_F16_F8_LANE : VInst<"vmlalb_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2174+
def VMLALB_F16_F8_LANEQ : VInst<"vmlalb_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2175+
def VMLALT_F16_F8_LANE : VInst<"vmlalt_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2176+
def VMLALT_F16_F8_LANEQ : VInst<"vmlalt_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2177+
2178+
def VMLALLBB_F32_F8_LANE : VInst<"vmlallbb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2179+
def VMLALLBB_F32_F8_LANEQ : VInst<"vmlallbb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2180+
def VMLALLBT_F32_F8_LANE : VInst<"vmlallbt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2181+
def VMLALLBT_F32_F8_LANEQ : VInst<"vmlallbt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2182+
def VMLALLTB_F32_F8_LANE : VInst<"vmlalltb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2183+
def VMLALLTB_F32_F8_LANEQ : VInst<"vmlalltb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2184+
def VMLALLTT_F32_F8_LANE : VInst<"vmlalltt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2185+
def VMLALLTT_F32_F8_LANEQ : VInst<"vmlalltt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
21722186
}
21732187

21742188
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6770,21 +6770,36 @@ Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
67706770
}
67716771

67726772
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
6773-
unsigned IID, bool ExtendLane, llvm::Type *RetTy,
6773+
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
67746774
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
67756775

67766776
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
67776777
RetTy->getPrimitiveSizeInBits();
67786778
llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount),
67796779
Ops[1]->getType()};
6780-
if (ExtendLane) {
6780+
if (ExtendLaneArg) {
67816781
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
67826782
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
67836783
Builder.getInt64(0));
67846784
}
67856785
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
67866786
}
67876787

6788+
llvm::Value *CodeGenFunction::EmitFP8NeonFMLACall(
6789+
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
6790+
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
6791+
6792+
if (ExtendLaneArg) {
6793+
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
6794+
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
6795+
Builder.getInt64(0));
6796+
}
6797+
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
6798+
RetTy->getPrimitiveSizeInBits();
6799+
return EmitFP8NeonCall(IID, {llvm::FixedVectorType::get(RetTy, ElemCount)},
6800+
Ops, E, name);
6801+
}
6802+
67886803
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
67896804
bool neg) {
67906805
int SV = cast<ConstantInt>(V)->getSExtValue();
@@ -12778,7 +12793,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1277812793

1277912794
unsigned Int;
1278012795
bool ExtractLow = false;
12781-
bool ExtendLane = false;
12796+
bool ExtendLaneArg = false;
1278212797
switch (BuiltinID) {
1278312798
default: return nullptr;
1278412799
case NEON::BI__builtin_neon_vbsl_v:
@@ -14053,24 +14068,24 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1405314068
Ops, E, "fdot2");
1405414069
case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm:
1405514070
case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm:
14056-
ExtendLane = true;
14071+
ExtendLaneArg = true;
1405714072
LLVM_FALLTHROUGH;
1405814073
case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm:
1405914074
case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm:
1406014075
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane,
14061-
ExtendLane, HalfTy, Ops, E, "fdot2_lane");
14076+
ExtendLaneArg, HalfTy, Ops, E, "fdot2_lane");
1406214077
case NEON::BI__builtin_neon_vdot_f32_mf8_fpm:
1406314078
case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm:
1406414079
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false,
1406514080
FloatTy, Ops, E, "fdot4");
1406614081
case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm:
1406714082
case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm:
14068-
ExtendLane = true;
14083+
ExtendLaneArg = true;
1406914084
LLVM_FALLTHROUGH;
1407014085
case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm:
1407114086
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
1407214087
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
14073-
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
14088+
ExtendLaneArg, FloatTy, Ops, E, "fdot4_lane");
1407414089

1407514090
case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
1407614091
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
@@ -14096,7 +14111,42 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1409614111
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
1409714112
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
1409814113
"vmlall");
14099-
14114+
case NEON::BI__builtin_neon_vmlalbq_lane_f16_mf8_fpm:
14115+
ExtendLaneArg = true;
14116+
LLVM_FALLTHROUGH;
14117+
case NEON::BI__builtin_neon_vmlalbq_laneq_f16_mf8_fpm:
14118+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalb_lane,
14119+
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
14120+
case NEON::BI__builtin_neon_vmlaltq_lane_f16_mf8_fpm:
14121+
ExtendLaneArg = true;
14122+
LLVM_FALLTHROUGH;
14123+
case NEON::BI__builtin_neon_vmlaltq_laneq_f16_mf8_fpm:
14124+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalt_lane,
14125+
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
14126+
case NEON::BI__builtin_neon_vmlallbbq_lane_f32_mf8_fpm:
14127+
ExtendLaneArg = true;
14128+
LLVM_FALLTHROUGH;
14129+
case NEON::BI__builtin_neon_vmlallbbq_laneq_f32_mf8_fpm:
14130+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbb_lane,
14131+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14132+
case NEON::BI__builtin_neon_vmlallbtq_lane_f32_mf8_fpm:
14133+
ExtendLaneArg = true;
14134+
LLVM_FALLTHROUGH;
14135+
case NEON::BI__builtin_neon_vmlallbtq_laneq_f32_mf8_fpm:
14136+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbt_lane,
14137+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14138+
case NEON::BI__builtin_neon_vmlalltbq_lane_f32_mf8_fpm:
14139+
ExtendLaneArg = true;
14140+
LLVM_FALLTHROUGH;
14141+
case NEON::BI__builtin_neon_vmlalltbq_laneq_f32_mf8_fpm:
14142+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltb_lane,
14143+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14144+
case NEON::BI__builtin_neon_vmlallttq_lane_f32_mf8_fpm:
14145+
ExtendLaneArg = true;
14146+
LLVM_FALLTHROUGH;
14147+
case NEON::BI__builtin_neon_vmlallttq_laneq_f32_mf8_fpm:
14148+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltt_lane,
14149+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
1410014150
case NEON::BI__builtin_neon_vamin_f16:
1410114151
case NEON::BI__builtin_neon_vaminq_f16:
1410214152
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4699,7 +4699,11 @@ class CodeGenFunction : public CodeGenTypeCache {
46994699
llvm::Type *Ty1, bool Extract,
47004700
SmallVectorImpl<llvm::Value *> &Ops,
47014701
const CallExpr *E, const char *name);
4702-
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane,
4702+
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLaneArg,
4703+
llvm::Type *RetTy,
4704+
SmallVectorImpl<llvm::Value *> &Ops,
4705+
const CallExpr *E, const char *name);
4706+
llvm::Value *EmitFP8NeonFMLACall(unsigned IID, bool ExtendLaneArg,
47034707
llvm::Type *RetTy,
47044708
SmallVectorImpl<llvm::Value *> &Ops,
47054709
const CallExpr *E, const char *name);

0 commit comments

Comments
 (0)