Skip to content

Commit 880e69b

Browse files
[AArch64] Implement NEON FP8 intrinsics for fused multiply-add (indexed)
This patch adds the following intrinsics: * Floating-point multiply-add long to half-precision (vector, by element) float16x8_t vmlalbq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlalbq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlaltq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlaltq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) * Floating-point multiply-add long-long to single-precision (vector, by element) float32x4_t vmlallbbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbtq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbtq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlalltbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlalltbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallttq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallttq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
1 parent 2b939e3 commit 880e69b

File tree

9 files changed

+429
-24
lines changed

9 files changed

+429
-24
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,6 +2175,20 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
21752175
def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
21762176
def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
21772177
def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2178+
2179+
def VMLALB_F16_F8_LANE : VInst<"vmlalbq_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2180+
def VMLALB_F16_F8_LANEQ : VInst<"vmlalbq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2181+
def VMLALT_F16_F8_LANE : VInst<"vmlaltq_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2182+
def VMLALT_F16_F8_LANEQ : VInst<"vmlaltq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2183+
2184+
def VMLALLBB_F32_F8_LANE : VInst<"vmlallbbq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2185+
def VMLALLBB_F32_F8_LANEQ : VInst<"vmlallbbq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2186+
def VMLALLBT_F32_F8_LANE : VInst<"vmlallbtq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2187+
def VMLALLBT_F32_F8_LANEQ : VInst<"vmlallbtq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2188+
def VMLALLTB_F32_F8_LANE : VInst<"vmlalltbq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2189+
def VMLALLTB_F32_F8_LANEQ : VInst<"vmlalltbq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2190+
def VMLALLTT_F32_F8_LANE : VInst<"vmlallttq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2191+
def VMLALLTT_F32_F8_LANEQ : VInst<"vmlallttq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
21782192
}
21792193

21802194
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6909,21 +6909,36 @@ Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
69096909
}
69106910

69116911
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
6912-
unsigned IID, bool ExtendLane, llvm::Type *RetTy,
6912+
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
69136913
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
69146914

69156915
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
69166916
RetTy->getPrimitiveSizeInBits();
69176917
llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount),
69186918
Ops[1]->getType()};
6919-
if (ExtendLane) {
6919+
if (ExtendLaneArg) {
69206920
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
69216921
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
69226922
Builder.getInt64(0));
69236923
}
69246924
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
69256925
}
69266926

6927+
llvm::Value *CodeGenFunction::EmitFP8NeonFMLACall(
6928+
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
6929+
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
6930+
6931+
if (ExtendLaneArg) {
6932+
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
6933+
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
6934+
Builder.getInt64(0));
6935+
}
6936+
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
6937+
RetTy->getPrimitiveSizeInBits();
6938+
return EmitFP8NeonCall(IID, {llvm::FixedVectorType::get(RetTy, ElemCount)},
6939+
Ops, E, name);
6940+
}
6941+
69276942
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
69286943
bool neg) {
69296944
int SV = cast<ConstantInt>(V)->getSExtValue();
@@ -12882,7 +12897,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1288212897

1288312898
unsigned Int;
1288412899
bool ExtractLow = false;
12885-
bool ExtendLane = false;
12900+
bool ExtendLaneArg = false;
1288612901
switch (BuiltinID) {
1288712902
default: return nullptr;
1288812903
case NEON::BI__builtin_neon_vbsl_v:
@@ -14157,24 +14172,24 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1415714172
Ops, E, "fdot2");
1415814173
case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm:
1415914174
case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm:
14160-
ExtendLane = true;
14175+
ExtendLaneArg = true;
1416114176
LLVM_FALLTHROUGH;
1416214177
case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm:
1416314178
case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm:
1416414179
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane,
14165-
ExtendLane, HalfTy, Ops, E, "fdot2_lane");
14180+
ExtendLaneArg, HalfTy, Ops, E, "fdot2_lane");
1416614181
case NEON::BI__builtin_neon_vdot_f32_mf8_fpm:
1416714182
case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm:
1416814183
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false,
1416914184
FloatTy, Ops, E, "fdot4");
1417014185
case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm:
1417114186
case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm:
14172-
ExtendLane = true;
14187+
ExtendLaneArg = true;
1417314188
LLVM_FALLTHROUGH;
1417414189
case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm:
1417514190
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
1417614191
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
14177-
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
14192+
ExtendLaneArg, FloatTy, Ops, E, "fdot4_lane");
1417814193

1417914194
case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
1418014195
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
@@ -14200,7 +14215,42 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1420014215
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
1420114216
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
1420214217
"vmlall");
14203-
14218+
case NEON::BI__builtin_neon_vmlalbq_lane_f16_mf8_fpm:
14219+
ExtendLaneArg = true;
14220+
LLVM_FALLTHROUGH;
14221+
case NEON::BI__builtin_neon_vmlalbq_laneq_f16_mf8_fpm:
14222+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalb_lane,
14223+
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
14224+
case NEON::BI__builtin_neon_vmlaltq_lane_f16_mf8_fpm:
14225+
ExtendLaneArg = true;
14226+
LLVM_FALLTHROUGH;
14227+
case NEON::BI__builtin_neon_vmlaltq_laneq_f16_mf8_fpm:
14228+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalt_lane,
14229+
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
14230+
case NEON::BI__builtin_neon_vmlallbbq_lane_f32_mf8_fpm:
14231+
ExtendLaneArg = true;
14232+
LLVM_FALLTHROUGH;
14233+
case NEON::BI__builtin_neon_vmlallbbq_laneq_f32_mf8_fpm:
14234+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbb_lane,
14235+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14236+
case NEON::BI__builtin_neon_vmlallbtq_lane_f32_mf8_fpm:
14237+
ExtendLaneArg = true;
14238+
LLVM_FALLTHROUGH;
14239+
case NEON::BI__builtin_neon_vmlallbtq_laneq_f32_mf8_fpm:
14240+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbt_lane,
14241+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14242+
case NEON::BI__builtin_neon_vmlalltbq_lane_f32_mf8_fpm:
14243+
ExtendLaneArg = true;
14244+
LLVM_FALLTHROUGH;
14245+
case NEON::BI__builtin_neon_vmlalltbq_laneq_f32_mf8_fpm:
14246+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltb_lane,
14247+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14248+
case NEON::BI__builtin_neon_vmlallttq_lane_f32_mf8_fpm:
14249+
ExtendLaneArg = true;
14250+
LLVM_FALLTHROUGH;
14251+
case NEON::BI__builtin_neon_vmlallttq_laneq_f32_mf8_fpm:
14252+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltt_lane,
14253+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
1420414254
case NEON::BI__builtin_neon_vamin_f16:
1420514255
case NEON::BI__builtin_neon_vaminq_f16:
1420614256
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4689,7 +4689,11 @@ class CodeGenFunction : public CodeGenTypeCache {
46894689
llvm::Type *Ty1, bool Extract,
46904690
SmallVectorImpl<llvm::Value *> &Ops,
46914691
const CallExpr *E, const char *name);
4692-
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane,
4692+
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLaneArg,
4693+
llvm::Type *RetTy,
4694+
SmallVectorImpl<llvm::Value *> &Ops,
4695+
const CallExpr *E, const char *name);
4696+
llvm::Value *EmitFP8NeonFMLACall(unsigned IID, bool ExtendLaneArg,
46934697
llvm::Type *RetTy,
46944698
SmallVectorImpl<llvm::Value *> &Ops,
46954699
const CallExpr *E, const char *name);

0 commit comments

Comments
 (0)