Skip to content

Commit 2e88b01

Browse files
[AArch64] Add Neon FP8 conversion intrinsics
[fixup] Add tests, fix calling the wrong LLVM intrinsic [fixup] Refactor much of common code into a helper function (NFC) [fixup] Add target features test, remove redundant bf16 guard [fixup] Clear the NoManglingQ flag for FP8 [fixup] Remove instcombine,tailcallelim from test run lines
1 parent d128c59 commit 2e88b01

File tree

11 files changed

+662
-22
lines changed

11 files changed

+662
-22
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2125,6 +2125,28 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "lut" in {
21252125
}
21262126
}
21272127

2128+
// Neon FP8 conversion intrinsics named vcvt{1,2}[_low|_high]_bf16_mf8_fpm.
// They are guarded to AArch64 with the "fp8" and "neon" target features.
// Each prototype ends in 'V', the fpm_t modifier; 'B' changes the type to
// BFloat16.
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
2129+
def VBF1CVT_BF16_MF8 : VInst<"vcvt1_bf16_mf8_fpm", "(QB).V", "m">;
2130+
def VBF1CVT_LOW_BF16_MF8 : VInst<"vcvt1_low_bf16_mf8_fpm", "B.V", "Hm">;
2131+
def VBF2CVTL_BF16_MF8 : VInst<"vcvt2_bf16_mf8_fpm", "(QB).V", "m">;
2132+
def VBF2CVTL_LOW_BF16_MF8 : VInst<"vcvt2_low_bf16_mf8_fpm", "B.V", "Hm">;
2133+
def VBF1CVTL2_HIGH_BF16_MF8 : VInst<"vcvt1_high_bf16_mf8_fpm", "B.V", "Hm">;
2134+
def VBF2CVTL2_HIGH_BF16_MF8 : VInst<"vcvt2_high_bf16_mf8_fpm", "B.V", "Hm">;
2135+
}
2136+
2137+
// Neon FP8 conversion intrinsics involving f16 and f32: the
// vcvt{1,2}[_low|_high]_f16_mf8_fpm forms and the vcvt[_high]_mf8_f32_fpm /
// vcvt_mf8_f16_fpm forms. They are guarded to AArch64 with the "fp8" and
// "neon" target features. Each prototype ends in 'V', the fpm_t modifier.
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
2138+
def VF1CVT_F16_MF8 : VInst<"vcvt1_f16_mf8_fpm", "(>QF).V", "m">;
2139+
def VF1CVT_LOW_F16_MF8 : VInst<"vcvt1_low_f16_mf8_fpm", "(>F).V", "Hm">;
2140+
def VF2CVTL_F16_MF8 : VInst<"vcvt2_f16_mf8_fpm", "(>QF).V", "m">;
2141+
def VF2CVTL_LOW_F16_MF8 : VInst<"vcvt2_low_f16_mf8_fpm", "(>F).V", "Hm">;
2142+
def VF1CVTL2_HIGH_F16_MF8 : VInst<"vcvt1_high_f16_mf8_fpm", "(>F).V", "Hm">;
2143+
def VF2CVTL2_HIGH_F16_MF8 : VInst<"vcvt2_high_f16_mf8_fpm", "(>F).V", "Hm">;
2144+
2145+
def VCVTN_LOW_F8_F32 : VInst<"vcvt_mf8_f32_fpm", ".(>>QF)(>>QF)V", "m">;
2146+
def VCVTN_HIGH_F8_F32 : VInst<"vcvt_high_mf8_f32_fpm", ".(q)(>>F)(>>F)V", "Hm">;
2147+
def VCVTN_F8_F16 : VInst<"vcvt_mf8_f16_fpm", ".(>F)(>F)V", "mQm">;
2148+
}
2149+
21282150
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
21292151
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
21302152
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;

clang/include/clang/Basic/arm_neon_incl.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def OP_UNAVAILABLE : Operation {
243243
// B: change to BFloat16
244244
// P: change to polynomial category.
245245
// p: change polynomial to equivalent integer category. Otherwise nop.
246+
// V: change to fpm_t
246247
//
247248
// >: double element width (vector size unchanged).
248249
// <: half element width (vector size unchanged).
@@ -301,6 +302,7 @@ class Inst <string n, string p, string t, Operation o, list<ImmCheck> ch = []>{
301302
class SInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
302303
class IInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
303304
class WInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
305+
// VInst: an Inst built from only a name, prototype and type string. It uses
// OP_NONE and passes no ImmCheck list, so Inst's `ch` keeps its default of [].
class VInst<string n, string p, string t> : Inst<n, p, t, OP_NONE> {}
304306

305307
// The following instruction classes are implemented via operators
306308
// instead of builtins. As such these declarations are only used for

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6899,12 +6899,36 @@ Value *CodeGenFunction::EmitNeonCall(Function *F, SmallVectorImpl<Value*> &Ops,
68996899
return Builder.CreateCall(F, Ops, name);
69006900
}
69016901

6902+
// Emit a NEON FP8 intrinsic call. A call to the aarch64_set_fpmr intrinsic
// with FPM as its argument is emitted first; then the call to F with Ops is
// emitted through EmitNeonCall, and that call's result is returned.
Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
6903+
SmallVectorImpl<Value *> &Ops,
6904+
Value *FPM, const char *name) {
6905+
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), FPM);
6906+
return EmitNeonCall(F, Ops, name);
6907+
}
6908+
69026909
// Build a constant of type Ty from the shift amount V. V is cast to
// ConstantInt and read as a sign-extended value; the value is negated when
// neg is true.
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
69036910
bool neg) {
69046911
int SV = cast<ConstantInt>(V)->getSExtValue();
69056912
return ConstantInt::get(Ty, neg ? -SV : SV);
69066913
}
69076914

6915+
// Emit an FP8 conversion intrinsic call. The intrinsic IID is looked up with
// the overload types {Ty0, Ty1}. When Extract is true, Ty1 is replaced by a
// fixed vector of 8 x Int8Ty and Ops[0] is replaced by the sub-vector
// extracted from it at index 0. The FPM value is taken from the last argument
// of the call expression E, and the call is emitted through EmitFP8NeonCall.
Value *CodeGenFunction::EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
6916+
llvm::Type *Ty1, bool Extract,
6917+
SmallVectorImpl<llvm::Value *> &Ops,
6918+
const CallExpr *E,
6919+
const char *name) {
6920+
llvm::Type *Tys[] = {Ty0, Ty1};
6921+
if (Extract) {
6922+
// Ops[0] is mfloat8x16_t, but the intrinsic converts only the lower part of
6923+
// the vector.
6924+
Tys[1] = llvm::FixedVectorType::get(Int8Ty, 8);
6925+
Ops[0] = Builder.CreateExtractVector(Tys[1], Ops[0], Builder.getInt64(0));
6926+
}
// The FPM operand is the last argument of the builtin call.
6927+
llvm::Value *FPM =
6928+
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6929+
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6930+
}
6931+
69086932
// Right-shift a vector by a constant.
69096933
Value *CodeGenFunction::EmitNeonRShiftImm(Value *Vec, Value *Shift,
69106934
llvm::Type *Ty, bool usgn,
@@ -12841,6 +12865,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1284112865
return V;
1284212866

1284312867
unsigned Int;
12868+
bool ExtractLow = false;
1284412869
switch (BuiltinID) {
1284512870
default: return nullptr;
1284612871
case NEON::BI__builtin_neon_vbsl_v:
@@ -14055,7 +14080,59 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1405514080
Int = Intrinsic::aarch64_neon_vluti4q_laneq_x2;
1405614081
return EmitNeonCall(CGM.getIntrinsic(Int, Ty), Ops, "vluti4q_laneq_x2");
1405714082
}
14058-
14083+
case NEON::BI__builtin_neon_vcvt1_low_bf16_mf8_fpm:
14084+
ExtractLow = true;
14085+
LLVM_FALLTHROUGH;
14086+
case NEON::BI__builtin_neon_vcvt1_bf16_mf8_fpm:
14087+
case NEON::BI__builtin_neon_vcvt1_high_bf16_mf8_fpm:
14088+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl1,
14089+
llvm::FixedVectorType::get(BFloatTy, 8),
14090+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt1");
14091+
case NEON::BI__builtin_neon_vcvt2_low_bf16_mf8_fpm:
14092+
ExtractLow = true;
14093+
LLVM_FALLTHROUGH;
14094+
case NEON::BI__builtin_neon_vcvt2_bf16_mf8_fpm:
14095+
case NEON::BI__builtin_neon_vcvt2_high_bf16_mf8_fpm:
14096+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl2,
14097+
llvm::FixedVectorType::get(BFloatTy, 8),
14098+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt2");
14099+
case NEON::BI__builtin_neon_vcvt1_low_f16_mf8_fpm:
14100+
ExtractLow = true;
14101+
LLVM_FALLTHROUGH;
14102+
case NEON::BI__builtin_neon_vcvt1_f16_mf8_fpm:
14103+
case NEON::BI__builtin_neon_vcvt1_high_f16_mf8_fpm:
14104+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl1,
14105+
llvm::FixedVectorType::get(HalfTy, 8),
14106+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt1");
14107+
case NEON::BI__builtin_neon_vcvt2_low_f16_mf8_fpm:
14108+
ExtractLow = true;
14109+
LLVM_FALLTHROUGH;
14110+
case NEON::BI__builtin_neon_vcvt2_f16_mf8_fpm:
14111+
case NEON::BI__builtin_neon_vcvt2_high_f16_mf8_fpm:
14112+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl2,
14113+
llvm::FixedVectorType::get(HalfTy, 8),
14114+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt2");
14115+
case NEON::BI__builtin_neon_vcvt_mf8_f32_fpm:
14116+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn,
14117+
llvm::FixedVectorType::get(Int8Ty, 8),
14118+
Ops[0]->getType(), false, Ops, E, "vfcvtn");
14119+
case NEON::BI__builtin_neon_vcvt_mf8_f16_fpm:
14120+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn,
14121+
llvm::FixedVectorType::get(Int8Ty, 8),
14122+
llvm::FixedVectorType::get(HalfTy, 4), false, Ops,
14123+
E, "vfcvtn");
14124+
case NEON::BI__builtin_neon_vcvtq_mf8_f16_fpm:
14125+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn,
14126+
llvm::FixedVectorType::get(Int8Ty, 16),
14127+
llvm::FixedVectorType::get(HalfTy, 8), false, Ops,
14128+
E, "vfcvtn");
14129+
case NEON::BI__builtin_neon_vcvt_high_mf8_f32_fpm: {
14130+
llvm::Type *Ty = llvm::FixedVectorType::get(Int8Ty, 16);
14131+
Ops[0] = Builder.CreateInsertVector(Ty, PoisonValue::get(Ty), Ops[0],
14132+
Builder.getInt64(0));
14133+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn2, Ty,
14134+
Ops[1]->getType(), false, Ops, E, "vfcvtn2");
14135+
}
1405914136
case NEON::BI__builtin_neon_vamin_f16:
1406014137
case NEON::BI__builtin_neon_vaminq_f16:
1406114138
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4691,6 +4691,13 @@ class CodeGenFunction : public CodeGenTypeCache {
46914691
SmallVectorImpl<llvm::Value*> &O,
46924692
const char *name,
46934693
unsigned shift = 0, bool rightshift = false);
4694+
/// Emit a call to the aarch64_set_fpmr intrinsic with \p FPM, then emit the
/// call to \p F with operands \p O via EmitNeonCall and return its result.
llvm::Value *EmitFP8NeonCall(llvm::Function *F,
4695+
SmallVectorImpl<llvm::Value *> &O,
4696+
llvm::Value *FPM, const char *name);
4697+
/// Emit the FP8 conversion intrinsic \p IID overloaded on {\p Ty0, \p Ty1}.
/// When \p Extract is true, only the sub-vector of 8 x i8 at index 0 of
/// Ops[0] is passed. The FPM value is read from the last argument of \p E.
llvm::Value *EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
4698+
llvm::Type *Ty1, bool Extract,
4699+
SmallVectorImpl<llvm::Value *> &Ops,
4700+
const CallExpr *E, const char *name);
46944701
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx,
46954702
const llvm::ElementCount &Count);
46964703
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx);

0 commit comments

Comments
 (0)