Skip to content

Commit 0c2e6ad

Browse files
[AArch64] Add Neon FP8 conversion intrinsics
1 parent 0de419a commit 0c2e6ad

File tree

11 files changed

+658
-24
lines changed

11 files changed

+658
-24
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2125,6 +2125,29 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "lut" in {
21252125
}
21262126
}
21272127

2128+
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
2129+
def VBF1CVT_BF16_MF8 : VInst<"vcvt1_bf16_mf8_fpm", "(QB).V", "m">;
2130+
def VBF1CVT_LOW_BF16_MF8 : VInst<"vcvt1_low_bf16_mf8_fpm", "B.V", "Qm">;
2131+
def VBF2CVTL_BF16_MF8 : VInst<"vcvt2_bf16_mf8_fpm", "(QB).V", "m">;
2132+
def VBF2CVTL_LOW_BF16_MF8 : VInst<"vcvt2_low_bf16_mf8_fpm", "B.V", "Qm">;
2133+
def VBF1CVTL2_HIGH_BF16_MF8 : VInst<"vcvt1_high_bf16_mf8_fpm", "B.V", "Qm">;
2134+
def VBF2CVTL2_HIGH_BF16_MF8 : VInst<"vcvt2_high_bf16_mf8_fpm", "B.V", "Qm">;
2135+
}
2136+
2137+
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
2138+
def VF1CVT_F16_MF8 : VInst<"vcvt1_f16_mf8_fpm", "(>QF).V", "m">;
2139+
def VF1CVT_LOW_F16_MF8 : VInst<"vcvt1_low_f16_mf8_fpm", "(>F).V", "Qm">;
2140+
def VF2CVTL_F16_MF8 : VInst<"vcvt2_f16_mf8_fpm", "(>QF).V", "m">;
2141+
def VF2CVTL_LOW_F16_MF8 : VInst<"vcvt2_low_f16_mf8_fpm", "(>F).V", "Qm">;
2142+
def VF1CVTL2_HIGH_F16_MF8 : VInst<"vcvt1_high_f16_mf8_fpm", "(>F).V", "Qm">;
2143+
def VF2CVTL2_HIGH_F16_MF8 : VInst<"vcvt2_high_f16_mf8_fpm", "(>F).V", "Qm">;
2144+
2145+
def VCVTN_LOW_F8_F32 : VInst<"vcvt_mf8_f32_fpm", ".(>>QF)(>>QF)V", "m">;
2146+
def VCVTN_HIGH_F8_F32 : VInst<"vcvt_high_mf8_f32_fpm", ".(q)(>>F)(>>F)V", "Qm">;
2147+
def VCVTN_F8_F16 : VInst<"vcvt_mf8_f16_fpm", ".(>F)(>F)V", "m">;
2148+
def VCVTNQ_F8_F16 : VInst<"vcvtq_mf8_f16_fpm", ".(>F)(>F)V", "Qm">;
2149+
}
2150+
21282151
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
21292152
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
21302153
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;
@@ -2134,4 +2157,4 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
21342157
// fscale
21352158
def FSCALE_V128 : WInst<"vscale", "..(.S)", "QdQfQh">;
21362159
def FSCALE_V64 : WInst<"vscale", "(.q)(.q)(.qS)", "fh">;
2137-
}
2160+
}

clang/include/clang/Basic/arm_neon_incl.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def OP_UNAVAILABLE : Operation {
243243
// B: change to BFloat16
244244
// P: change to polynomial category.
245245
// p: change polynomial to equivalent integer category. Otherwise nop.
246+
// V: change to fpm_t
246247
//
247248
// >: double element width (vector size unchanged).
248249
// <: half element width (vector size unchanged).
@@ -301,6 +302,7 @@ class Inst <string n, string p, string t, Operation o, list<ImmCheck> ch = []>{
301302
class SInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
302303
class IInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
303304
class WInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
305+
class VInst<string n, string p, string t> : Inst<n, p, t, OP_NONE> {}
304306

305307
// The following instruction classes are implemented via operators
306308
// instead of builtins. As such these declarations are only used for

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6869,12 +6869,36 @@ Value *CodeGenFunction::EmitNeonCall(Function *F, SmallVectorImpl<Value*> &Ops,
68696869
return Builder.CreateCall(F, Ops, name);
68706870
}
68716871

6872+
Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
6873+
SmallVectorImpl<Value *> &Ops,
6874+
Value *FPM, const char *name) {
6875+
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), FPM);
6876+
return EmitNeonCall(F, Ops, name);
6877+
}
6878+
68726879
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
68736880
bool neg) {
68746881
int SV = cast<ConstantInt>(V)->getSExtValue();
68756882
return ConstantInt::get(Ty, neg ? -SV : SV);
68766883
}
68776884

6885+
Value *CodeGenFunction::EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
6886+
llvm::Type *Ty1, bool Extract,
6887+
SmallVectorImpl<llvm::Value *> &Ops,
6888+
const CallExpr *E,
6889+
const char *name) {
6890+
llvm::Type *Tys[] = {Ty0, Ty1};
6891+
if (Extract) {
6892+
// Op[0] is mfloat8x16_t, but the intrinsic converts only the lower part of
6893+
// the vector.
6894+
Tys[1] = llvm::FixedVectorType::get(Int8Ty, 8);
6895+
Ops[0] = Builder.CreateExtractVector(Tys[1], Ops[0], Builder.getInt64(0));
6896+
}
6897+
llvm::Value *FPM =
6898+
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6899+
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6900+
}
6901+
68786902
// Right-shift a vector by a constant.
68796903
Value *CodeGenFunction::EmitNeonRShiftImm(Value *Vec, Value *Shift,
68806904
llvm::Type *Ty, bool usgn,
@@ -12798,6 +12822,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1279812822
return V;
1279912823

1280012824
unsigned Int;
12825+
bool ExtractLow = false;
1280112826
switch (BuiltinID) {
1280212827
default: return nullptr;
1280312828
case NEON::BI__builtin_neon_vbsl_v:
@@ -14012,7 +14037,59 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1401214037
Int = Intrinsic::aarch64_neon_vluti4q_laneq_x2;
1401314038
return EmitNeonCall(CGM.getIntrinsic(Int, Ty), Ops, "vluti4q_laneq_x2");
1401414039
}
14015-
14040+
case NEON::BI__builtin_neon_vcvt1_low_bf16_mf8_fpm:
14041+
ExtractLow = true;
14042+
LLVM_FALLTHROUGH;
14043+
case NEON::BI__builtin_neon_vcvt1_bf16_mf8_fpm:
14044+
case NEON::BI__builtin_neon_vcvt1_high_bf16_mf8_fpm:
14045+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl1,
14046+
llvm::FixedVectorType::get(BFloatTy, 8),
14047+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt1");
14048+
case NEON::BI__builtin_neon_vcvt2_low_bf16_mf8_fpm:
14049+
ExtractLow = true;
14050+
LLVM_FALLTHROUGH;
14051+
case NEON::BI__builtin_neon_vcvt2_bf16_mf8_fpm:
14052+
case NEON::BI__builtin_neon_vcvt2_high_bf16_mf8_fpm:
14053+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl2,
14054+
llvm::FixedVectorType::get(BFloatTy, 8),
14055+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt2");
14056+
case NEON::BI__builtin_neon_vcvt1_low_f16_mf8_fpm:
14057+
ExtractLow = true;
14058+
LLVM_FALLTHROUGH;
14059+
case NEON::BI__builtin_neon_vcvt1_f16_mf8_fpm:
14060+
case NEON::BI__builtin_neon_vcvt1_high_f16_mf8_fpm:
14061+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl1,
14062+
llvm::FixedVectorType::get(HalfTy, 8),
14063+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt1");
14064+
case NEON::BI__builtin_neon_vcvt2_low_f16_mf8_fpm:
14065+
ExtractLow = true;
14066+
LLVM_FALLTHROUGH;
14067+
case NEON::BI__builtin_neon_vcvt2_f16_mf8_fpm:
14068+
case NEON::BI__builtin_neon_vcvt2_high_f16_mf8_fpm:
14069+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl2,
14070+
llvm::FixedVectorType::get(HalfTy, 8),
14071+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt2");
14072+
case NEON::BI__builtin_neon_vcvt_mf8_f32_fpm:
14073+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn,
14074+
llvm::FixedVectorType::get(Int8Ty, 8),
14075+
Ops[0]->getType(), false, Ops, E, "vfcvtn");
14076+
case NEON::BI__builtin_neon_vcvt_mf8_f16_fpm:
14077+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn,
14078+
llvm::FixedVectorType::get(Int8Ty, 8),
14079+
llvm::FixedVectorType::get(HalfTy, 4), false, Ops,
14080+
E, "vfcvtn");
14081+
case NEON::BI__builtin_neon_vcvtq_mf8_f16_fpm:
14082+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn,
14083+
llvm::FixedVectorType::get(Int8Ty, 16),
14084+
llvm::FixedVectorType::get(HalfTy, 8), false, Ops,
14085+
E, "vfcvtn");
14086+
case NEON::BI__builtin_neon_vcvt_high_mf8_f32_fpm: {
14087+
llvm::Type *Ty = llvm::FixedVectorType::get(Int8Ty, 16);
14088+
Ops[0] = Builder.CreateInsertVector(Ty, PoisonValue::get(Ty), Ops[0],
14089+
Builder.getInt64(0));
14090+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn2,
14091+
Ty, Ops[1]->getType(), false, Ops, E, "vfcvtn2");
14092+
}
1401614093
case NEON::BI__builtin_neon_vamin_f16:
1401714094
case NEON::BI__builtin_neon_vaminq_f16:
1401814095
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4648,6 +4648,13 @@ class CodeGenFunction : public CodeGenTypeCache {
46484648
SmallVectorImpl<llvm::Value*> &O,
46494649
const char *name,
46504650
unsigned shift = 0, bool rightshift = false);
4651+
llvm::Value *EmitFP8NeonCall(llvm::Function *F,
4652+
SmallVectorImpl<llvm::Value *> &O,
4653+
llvm::Value *FPM, const char *name);
4654+
llvm::Value *EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
4655+
llvm::Type *Ty1, bool Extract,
4656+
SmallVectorImpl<llvm::Value *> &Ops,
4657+
const CallExpr *E, const char *name);
46514658
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx,
46524659
const llvm::ElementCount &Count);
46534660
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx);

0 commit comments

Comments
 (0)