Skip to content

Commit afc66a8

Browse files
[AArch64] Implement NEON FP8 intrinsics for fused multiply-add (indexed)
This patch adds the following intrinsics:

* Floating-point multiply-add long to half-precision (vector, by element)
  float16x8_t vmlalbq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float16x8_t vmlalbq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float16x8_t vmlaltq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float16x8_t vmlaltq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)

* Floating-point multiply-add long-long to single-precision (vector, by element)
  float32x4_t vmlallbbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlallbbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlallbtq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlallbtq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlalltbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlalltbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlallttq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
  float32x4_t vmlallttq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)

[fixup] Update intrinsics definitions
[fixup] Regenerate tests
1 parent 2ceaf4b commit afc66a8

File tree

9 files changed

+445
-24
lines changed

9 files changed

+445
-24
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,6 +2175,20 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
21752175
def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
21762176
def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
21772177
def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2178+
2179+
def VMLALB_F16_F8_LANE : VInst<"vmlalb_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2180+
def VMLALB_F16_F8_LANEQ : VInst<"vmlalb_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2181+
def VMLALT_F16_F8_LANE : VInst<"vmlalt_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2182+
def VMLALT_F16_F8_LANEQ : VInst<"vmlalt_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2183+
2184+
def VMLALLBB_F32_F8_LANE : VInst<"vmlallbb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2185+
def VMLALLBB_F32_F8_LANEQ : VInst<"vmlallbb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2186+
def VMLALLBT_F32_F8_LANE : VInst<"vmlallbt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2187+
def VMLALLBT_F32_F8_LANEQ : VInst<"vmlallbt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2188+
def VMLALLTB_F32_F8_LANE : VInst<"vmlalltb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2189+
def VMLALLTB_F32_F8_LANEQ : VInst<"vmlalltb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2190+
def VMLALLTT_F32_F8_LANE : VInst<"vmlalltt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2191+
def VMLALLTT_F32_F8_LANEQ : VInst<"vmlalltt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
21782192
}
21792193

21802194
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6910,21 +6910,36 @@ Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
69106910
}
69116911

69126912
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
6913-
unsigned IID, bool ExtendLane, llvm::Type *RetTy,
6913+
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
69146914
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
69156915

69166916
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
69176917
RetTy->getPrimitiveSizeInBits();
69186918
llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount),
69196919
Ops[1]->getType()};
6920-
if (ExtendLane) {
6920+
if (ExtendLaneArg) {
69216921
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
69226922
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
69236923
Builder.getInt64(0));
69246924
}
69256925
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
69266926
}
69276927

6928+
llvm::Value *CodeGenFunction::EmitFP8NeonFMLACall(
6929+
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
6930+
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
6931+
6932+
if (ExtendLaneArg) {
6933+
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
6934+
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
6935+
Builder.getInt64(0));
6936+
}
6937+
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
6938+
RetTy->getPrimitiveSizeInBits();
6939+
return EmitFP8NeonCall(IID, {llvm::FixedVectorType::get(RetTy, ElemCount)},
6940+
Ops, E, name);
6941+
}
6942+
69286943
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
69296944
bool neg) {
69306945
int SV = cast<ConstantInt>(V)->getSExtValue();
@@ -12883,7 +12898,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1288312898

1288412899
unsigned Int;
1288512900
bool ExtractLow = false;
12886-
bool ExtendLane = false;
12901+
bool ExtendLaneArg = false;
1288712902
switch (BuiltinID) {
1288812903
default: return nullptr;
1288912904
case NEON::BI__builtin_neon_vbsl_v:
@@ -14158,24 +14173,24 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1415814173
Ops, E, "fdot2");
1415914174
case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm:
1416014175
case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm:
14161-
ExtendLane = true;
14176+
ExtendLaneArg = true;
1416214177
LLVM_FALLTHROUGH;
1416314178
case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm:
1416414179
case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm:
1416514180
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane,
14166-
ExtendLane, HalfTy, Ops, E, "fdot2_lane");
14181+
ExtendLaneArg, HalfTy, Ops, E, "fdot2_lane");
1416714182
case NEON::BI__builtin_neon_vdot_f32_mf8_fpm:
1416814183
case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm:
1416914184
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false,
1417014185
FloatTy, Ops, E, "fdot4");
1417114186
case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm:
1417214187
case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm:
14173-
ExtendLane = true;
14188+
ExtendLaneArg = true;
1417414189
LLVM_FALLTHROUGH;
1417514190
case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm:
1417614191
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
1417714192
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
14178-
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
14193+
ExtendLaneArg, FloatTy, Ops, E, "fdot4_lane");
1417914194

1418014195
case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
1418114196
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
@@ -14201,7 +14216,42 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1420114216
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
1420214217
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
1420314218
"vmlall");
14204-
14219+
case NEON::BI__builtin_neon_vmlalbq_lane_f16_mf8_fpm:
14220+
ExtendLaneArg = true;
14221+
LLVM_FALLTHROUGH;
14222+
case NEON::BI__builtin_neon_vmlalbq_laneq_f16_mf8_fpm:
14223+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalb_lane,
14224+
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
14225+
case NEON::BI__builtin_neon_vmlaltq_lane_f16_mf8_fpm:
14226+
ExtendLaneArg = true;
14227+
LLVM_FALLTHROUGH;
14228+
case NEON::BI__builtin_neon_vmlaltq_laneq_f16_mf8_fpm:
14229+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalt_lane,
14230+
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
14231+
case NEON::BI__builtin_neon_vmlallbbq_lane_f32_mf8_fpm:
14232+
ExtendLaneArg = true;
14233+
LLVM_FALLTHROUGH;
14234+
case NEON::BI__builtin_neon_vmlallbbq_laneq_f32_mf8_fpm:
14235+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbb_lane,
14236+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14237+
case NEON::BI__builtin_neon_vmlallbtq_lane_f32_mf8_fpm:
14238+
ExtendLaneArg = true;
14239+
LLVM_FALLTHROUGH;
14240+
case NEON::BI__builtin_neon_vmlallbtq_laneq_f32_mf8_fpm:
14241+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbt_lane,
14242+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14243+
case NEON::BI__builtin_neon_vmlalltbq_lane_f32_mf8_fpm:
14244+
ExtendLaneArg = true;
14245+
LLVM_FALLTHROUGH;
14246+
case NEON::BI__builtin_neon_vmlalltbq_laneq_f32_mf8_fpm:
14247+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltb_lane,
14248+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14249+
case NEON::BI__builtin_neon_vmlallttq_lane_f32_mf8_fpm:
14250+
ExtendLaneArg = true;
14251+
LLVM_FALLTHROUGH;
14252+
case NEON::BI__builtin_neon_vmlallttq_laneq_f32_mf8_fpm:
14253+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltt_lane,
14254+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
1420514255
case NEON::BI__builtin_neon_vamin_f16:
1420614256
case NEON::BI__builtin_neon_vaminq_f16:
1420714257
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4698,7 +4698,11 @@ class CodeGenFunction : public CodeGenTypeCache {
46984698
llvm::Type *Ty1, bool Extract,
46994699
SmallVectorImpl<llvm::Value *> &Ops,
47004700
const CallExpr *E, const char *name);
4701-
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane,
4701+
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLaneArg,
4702+
llvm::Type *RetTy,
4703+
SmallVectorImpl<llvm::Value *> &Ops,
4704+
const CallExpr *E, const char *name);
4705+
llvm::Value *EmitFP8NeonFMLACall(unsigned IID, bool ExtendLaneArg,
47024706
llvm::Type *RetTy,
47034707
SmallVectorImpl<llvm::Value *> &Ops,
47044708
const CallExpr *E, const char *name);

0 commit comments

Comments
 (0)