Skip to content

Commit b755f98

Browse files
[AArch64] Add FP8 Neon intrinsics for dot-product
This patch adds the following intrinsics: float16x4_t vdot_f16_mf8_fpm(float16x4_t vd, mfloat8x8_t vn, mfloat8x8_t vm, fpm_t fpm); float16x8_t vdotq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm); float16x4_t vdot_lane_f16_mf8_fpm(float16x4_t vd, mfloat8x8_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm); float16x4_t vdot_laneq_f16_mf8_fpm(float16x4_t vd, mfloat8x8_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm); float16x8_t vdotq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm); float16x8_t vdotq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
1 parent 0c2e6ad commit b755f98

File tree

10 files changed

+511
-40
lines changed

10 files changed

+511
-40
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2148,6 +2148,28 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
21482148
def VCVTNQ_F8_F16 : VInst<"vcvtq_mf8_f16_fpm", ".(>F)(>F)V", "Qm">;
21492149
}
21502150

2151+
// FP8 dot-product intrinsics that accumulate into half-precision vectors
// (the *_f16_mf8_fpm names). All six records are only available when the
// "fp8dot2" and "neon" target features are enabled on AArch64.
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot2,neon" in {
2152+
// Non-indexed forms: 64-bit ("m") and 128-bit ("Qm") variants.
def VDOT_F16_MF8 : VInst<"vdot_f16_mf8_fpm", "(>F)(>F)..V", "m">;
2153+
def VDOTQ_F16_MF8 : VInst<"vdotq_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
2154+
2155+
// Indexed (lane/laneq) forms. Each carries an ImmCheck on operand 3, which
// is the lane argument in the prototypes listed in the commit message; the
// allowed range is 0..3 for the _lane forms and 0..7 for the _laneq forms.
def VDOT_LANE_F16_MF8 : VInst<"vdot_lane_f16_mf8_fpm", "(>F)(>F)..IV", "m", [ImmCheck<3, ImmCheck0_3, 0>]>;
2156+
def VDOT_LANEQ_F16_MF8 : VInst<"vdot_laneq_f16_mf8_fpm", "(>F)(>F).QIV", "m", [ImmCheck<3, ImmCheck0_7, 0>]>;
2157+
2158+
def VDOTQ_LANE_F16_MF8 : VInst<"vdotq_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
2159+
def VDOTQ_LANEQ_F16_MF8 : VInst<"vdotq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2160+
}
2161+
2162+
// FP8 dot-product intrinsics that accumulate into single-precision vectors
// (the *_f32_mf8_fpm names). All six records are only available when the
// "fp8dot4" and "neon" target features are enabled on AArch64.
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
2163+
// Non-indexed forms: 64-bit ("m") and 128-bit ("Qm") variants.
def VDOT_F32_MF8 : VInst<"vdot_f32_mf8_fpm", "(>>F)(>>F)..V", "m">;
2164+
def VDOTQ_F32_MF8 : VInst<"vdotq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2165+
2166+
// Indexed (lane/laneq) forms. Each carries an ImmCheck on operand 3; the
// allowed range is 0..1 for the _lane forms and 0..3 for the _laneq forms,
// i.e. half the range used by the corresponding f16 records.
def VDOT_LANE_F32_MF8 : VInst<"vdot_lane_f32_mf8_fpm", "(>>F)(>>F)..IV", "m", [ImmCheck<3, ImmCheck0_1, 0>]>;
2167+
def VDOT_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F).QIV", "m", [ImmCheck<3, ImmCheck0_3, 0>]>;
2168+
2169+
def VDOTQ_LANE_F32_MF8 : VInst<"vdotq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_1, 0>]>;
2170+
def VDOTQ_LANEQ_F32_MF8 : VInst<"vdotq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
2171+
}
2172+
21512173
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
21522174
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
21532175
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;

clang/include/clang/Basic/arm_neon_incl.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ class Inst <string n, string p, string t, Operation o, list<ImmCheck> ch = []>{
302302
class SInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
303303
class IInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
304304
class WInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
305-
class VInst<string n, string p, string t> : Inst<n, p, t, OP_NONE> {}
305+
// VInst takes an optional list of immediate checks (empty by default) and
// forwards it to Inst, in the same way as SInst, IInst and WInst.
class VInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
306306

307307
// The following instruction classes are implemented via operators
308308
// instead of builtins. As such these declarations are only used for

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6876,6 +6876,24 @@ Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
68766876
return EmitNeonCall(F, Ops, name);
68776877
}
68786878

6879+
// Emits a call to the FP8 Neon dot-product intrinsic identified by IID.
//
//   IID        - intrinsic ID passed to CGM.getIntrinsic.
//   ExtendLane - when true, Ops[2] is first placed into a 16 x i8 vector.
//   RetTy      - element type used to build the result vector type.
//   Ops        - call operands; Ops[2] may be rewritten in place.
//   E          - the builtin call expression; its last argument supplies FPM.
//   name       - name forwarded to EmitFP8NeonCall.
//
// Returns the value produced by EmitFP8NeonCall.
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
6880+
unsigned IID, bool ExtendLane, llvm::Type *RetTy,
6881+
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
6882+
6883+
// Number of RetTy-sized elements that fit in the bit width of Ops[0]'s type.
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
6884+
RetTy->getPrimitiveSizeInBits();
6885+
// Overload types for the intrinsic: the <ElemCount x RetTy> vector and the
// type of Ops[1].
llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount),
6886+
Ops[1]->getType()};
6887+
if (ExtendLane) {
6888+
// Insert Ops[2] at index 0 of a poison <16 x i8> vector so the operand
// has the 16-element type; the remaining elements stay poison.
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
6889+
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
6890+
Builder.getInt64(0));
6891+
}
6892+
// FPM is taken from the call's last argument (index getNumArgs() - 1).
llvm::Value *FPM =
6893+
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6894+
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6895+
}
6896+
68796897
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
68806898
bool neg) {
68816899
int SV = cast<ConstantInt>(V)->getSExtValue();
@@ -12823,6 +12841,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1282312841

1282412842
unsigned Int;
1282512843
bool ExtractLow = false;
12844+
bool ExtendLane = false;
1282612845
switch (BuiltinID) {
1282712846
default: return nullptr;
1282812847
case NEON::BI__builtin_neon_vbsl_v:
@@ -14090,6 +14109,31 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1409014109
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn2,
1409114110
Ty, Ops[1]->getType(), false, Ops, E, "vfcvtn2");
1409214111
}
14112+
14113+
case NEON::BI__builtin_neon_vdot_f16_mf8_fpm:
14114+
case NEON::BI__builtin_neon_vdotq_f16_mf8_fpm:
14115+
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2, false, HalfTy,
14116+
Ops, E, "fdot2");
14117+
case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm:
14118+
case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm:
14119+
ExtendLane = true;
14120+
LLVM_FALLTHROUGH;
14121+
case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm:
14122+
case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm:
14123+
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane,
14124+
ExtendLane, HalfTy, Ops, E, "fdot2_lane");
14125+
case NEON::BI__builtin_neon_vdot_f32_mf8_fpm:
14126+
case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm:
14127+
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false,
14128+
FloatTy, Ops, E, "fdot4");
14129+
case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm:
14130+
case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm:
14131+
ExtendLane = true;
14132+
LLVM_FALLTHROUGH;
14133+
case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm:
14134+
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
14135+
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
14136+
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
1409314137
case NEON::BI__builtin_neon_vamin_f16:
1409414138
case NEON::BI__builtin_neon_vaminq_f16:
1409514139
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4655,6 +4655,10 @@ class CodeGenFunction : public CodeGenTypeCache {
46554655
llvm::Type *Ty1, bool Extract,
46564656
SmallVectorImpl<llvm::Value *> &Ops,
46574657
const CallExpr *E, const char *name);
4658+
/// Emit a call to the FP8 Neon dot-product intrinsic \p IID. When
/// \p ExtendLane is true, Ops[2] is widened into a 16 x i8 vector before
/// the call; \p RetTy is the element type of the result vector.
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane,
4659+
llvm::Type *RetTy,
4660+
SmallVectorImpl<llvm::Value *> &Ops,
4661+
const CallExpr *E, const char *name);
46584662
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx,
46594663
const llvm::ElementCount &Count);
46604664
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx);

0 commit comments

Comments
 (0)