Skip to content

Commit 99bd2e3

Browse files
[AArch64] Add Neon FP8 conversion intrinsics (llvm#123612)
The patch adds the following intrinsics:

bfloat16x8_t vcvt1_bf16_mf8_fpm(mfloat8x8_t vn, fpm_t fpm)
bfloat16x8_t vcvt1_low_bf16_mf8_fpm(mfloat8x16_t vn, fpm_t fpm)
bfloat16x8_t vcvt2_bf16_mf8_fpm(mfloat8x8_t vn, fpm_t fpm)
bfloat16x8_t vcvt2_low_bf16_mf8_fpm(mfloat8x16_t vn, fpm_t fpm)
bfloat16x8_t vcvt1_high_bf16_mf8_fpm(mfloat8x16_t vn, fpm_t fpm)
bfloat16x8_t vcvt2_high_bf16_mf8_fpm(mfloat8x16_t vn, fpm_t fpm)

float16x8_t vcvt1_f16_mf8_fpm(mfloat8x8_t vn, fpm_t fpm)
float16x8_t vcvt1_low_f16_mf8_fpm(mfloat8x16_t vn, fpm_t fpm)
float16x8_t vcvt2_f16_mf8_fpm(mfloat8x8_t vn, fpm_t fpm)
float16x8_t vcvt2_low_f16_mf8_fpm(mfloat8x16_t vn, fpm_t fpm)
float16x8_t vcvt1_high_f16_mf8_fpm(mfloat8x16_t vn, fpm_t fpm)
float16x8_t vcvt2_high_f16_mf8_fpm(mfloat8x16_t vn, fpm_t fpm)

mfloat8x8_t vcvt_mf8_f32_fpm(float32x4_t vn, float32x4_t vm, fpm_t fpm)
mfloat8x16_t vcvt_high_mf8_f32_fpm(mfloat8x8_t vd, float32x4_t vn, float32x4_t vm, fpm_t fpm)
mfloat8x8_t vcvt_mf8_f16_fpm(float16x4_t vn, float16x4_t vm, fpm_t fpm)
mfloat8x16_t vcvtq_mf8_f16_fpm(float16x8_t vn, float16x8_t vm, fpm_t fpm)

Co-Authored-By: Caroline Concatto <[email protected]>
1 parent 19f0524 commit 99bd2e3

File tree

11 files changed

+662
-22
lines changed

11 files changed

+662
-22
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2119,6 +2119,28 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "lut" in {
21192119
}
21202120
}
21212121

2122+
// mfloat8 -> bfloat16 conversion intrinsics (vcvt1/vcvt2, plus their _low and
// _high forms). Guarded on AArch64 with the "fp8" and "neon" target features.
// Every prototype ends in 'V', i.e. the last argument is an fpm_t.
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
2123+
def VBF1CVT_BF16_MF8 : VInst<"vcvt1_bf16_mf8_fpm", "(QB).V", "m">;
2124+
def VBF1CVT_LOW_BF16_MF8 : VInst<"vcvt1_low_bf16_mf8_fpm", "B.V", "Hm">;
2125+
def VBF2CVTL_BF16_MF8 : VInst<"vcvt2_bf16_mf8_fpm", "(QB).V", "m">;
2126+
def VBF2CVTL_LOW_BF16_MF8 : VInst<"vcvt2_low_bf16_mf8_fpm", "B.V", "Hm">;
2127+
def VBF1CVTL2_HIGH_BF16_MF8 : VInst<"vcvt1_high_bf16_mf8_fpm", "B.V", "Hm">;
2128+
def VBF2CVTL2_HIGH_BF16_MF8 : VInst<"vcvt2_high_bf16_mf8_fpm", "B.V", "Hm">;
2129+
}
2130+
2131+
// mfloat8 -> float16 conversion intrinsics (vcvt1/vcvt2, plus their _low and
// _high forms), followed by the conversions that produce mfloat8 results.
// Guarded on AArch64 with the "fp8" and "neon" target features. Every
// prototype ends in 'V', i.e. the last argument is an fpm_t.
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
2132+
def VF1CVT_F16_MF8 : VInst<"vcvt1_f16_mf8_fpm", "(>QF).V", "m">;
2133+
def VF1CVT_LOW_F16_MF8 : VInst<"vcvt1_low_f16_mf8_fpm", "(>F).V", "Hm">;
2134+
def VF2CVTL_F16_MF8 : VInst<"vcvt2_f16_mf8_fpm", "(>QF).V", "m">;
2135+
def VF2CVTL_LOW_F16_MF8 : VInst<"vcvt2_low_f16_mf8_fpm", "(>F).V", "Hm">;
2136+
def VF1CVTL2_HIGH_F16_MF8 : VInst<"vcvt1_high_f16_mf8_fpm", "(>F).V", "Hm">;
2137+
def VF2CVTL2_HIGH_F16_MF8 : VInst<"vcvt2_high_f16_mf8_fpm", "(>F).V", "Hm">;
2138+
2139+
// Conversions from float32 / float16 source vectors to mfloat8. The single
// VCVTN_F8_F16 def uses the type list "mQm", so it covers both the 64-bit and
// the Q-register (vcvtq_mf8_f16_fpm) variant.
def VCVTN_LOW_F8_F32 : VInst<"vcvt_mf8_f32_fpm", ".(>>QF)(>>QF)V", "m">;
2140+
def VCVTN_HIGH_F8_F32 : VInst<"vcvt_high_mf8_f32_fpm", ".(q)(>>F)(>>F)V", "Hm">;
2141+
def VCVTN_F8_F16 : VInst<"vcvt_mf8_f16_fpm", ".(>F)(>F)V", "mQm">;
2142+
}
2143+
21222144
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
21232145
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
21242146
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;

clang/include/clang/Basic/arm_neon_incl.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def OP_UNAVAILABLE : Operation {
243243
// B: change to BFloat16
244244
// P: change to polynomial category.
245245
// p: change polynomial to equivalent integer category. Otherwise nop.
246+
// V: change to fpm_t
246247
//
247248
// >: double element width (vector size unchanged).
248249
// <: half element width (vector size unchanged).
@@ -301,6 +302,7 @@ class Inst <string n, string p, string t, Operation o, list<ImmCheck> ch = []>{
301302
class SInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
302303
class IInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
303304
class WInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
305+
// VInst: like SInst/IInst/WInst it maps to Inst with OP_NONE, but it takes no
// ImmCheck list (the base class default of [] is used).
class VInst<string n, string p, string t> : Inst<n, p, t, OP_NONE> {}
304306

305307
// The following instruction classes are implemented via operators
306308
// instead of builtins. As such these declarations are only used for

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6759,12 +6759,36 @@ Value *CodeGenFunction::EmitNeonCall(Function *F, SmallVectorImpl<Value*> &Ops,
67596759
return Builder.CreateCall(F, Ops, name);
67606760
}
67616761

6762+
/// Emit a NEON intrinsic call that is preceded by a write of the FP8 mode.
///
/// First emits a call to the `aarch64_set_fpmr` intrinsic with \p FPM as its
/// argument, then forwards \p F, \p Ops and \p name to EmitNeonCall and
/// returns that call's result.
Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
6763+
SmallVectorImpl<Value *> &Ops,
6764+
Value *FPM, const char *name) {
6765+
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), FPM);
6766+
return EmitNeonCall(F, Ops, name);
6767+
}
6768+
67626769
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
67636770
bool neg) {
67646771
int SV = cast<ConstantInt>(V)->getSExtValue();
67656772
return ConstantInt::get(Ty, neg ? -SV : SV);
67666773
}
67676774

6775+
/// Emit an FP8 conversion intrinsic call.
///
/// \param IID     ID of the intrinsic to call; it is overloaded on two types.
/// \param Ty0     First overload type passed to getIntrinsic.
/// \param Ty1     Second overload type; replaced by <8 x i8> when \p Extract
///                is set.
/// \param Extract When true, Ops[0] is overwritten in place with a <8 x i8>
///                subvector extracted from it at index 0.
/// \param Ops     Operands for the intrinsic call (may be modified, see
///                \p Extract).
/// \param E       The builtin call expression; its last argument supplies the
///                FPM value.
/// \param name    Name forwarded to EmitFP8NeonCall.
/// \returns the value produced by EmitFP8NeonCall.
Value *CodeGenFunction::EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
6776+
llvm::Type *Ty1, bool Extract,
6777+
SmallVectorImpl<llvm::Value *> &Ops,
6778+
const CallExpr *E,
6779+
const char *name) {
6780+
llvm::Type *Tys[] = {Ty0, Ty1};
6781+
if (Extract) {
6782+
// Op[0] is mfloat8x16_t, but the intrinsic converts only the lower part of
6783+
// the vector.
6784+
Tys[1] = llvm::FixedVectorType::get(Int8Ty, 8);
6785+
Ops[0] = Builder.CreateExtractVector(Tys[1], Ops[0], Builder.getInt64(0));
6786+
}
// The FPM operand is taken from the last argument of the call expression.
6787+
llvm::Value *FPM =
6788+
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6789+
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6790+
}
6791+
67686792
// Right-shift a vector by a constant.
67696793
Value *CodeGenFunction::EmitNeonRShiftImm(Value *Vec, Value *Shift,
67706794
llvm::Type *Ty, bool usgn,
@@ -12736,6 +12760,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1273612760
return V;
1273712761

1273812762
unsigned Int;
12763+
bool ExtractLow = false;
1273912764
switch (BuiltinID) {
1274012765
default: return nullptr;
1274112766
case NEON::BI__builtin_neon_vbsl_v:
@@ -13950,7 +13975,59 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1395013975
Int = Intrinsic::aarch64_neon_vluti4q_laneq_x2;
1395113976
return EmitNeonCall(CGM.getIntrinsic(Int, Ty), Ops, "vluti4q_laneq_x2");
1395213977
}
13953-
13978+
case NEON::BI__builtin_neon_vcvt1_low_bf16_mf8_fpm:
13979+
ExtractLow = true;
13980+
LLVM_FALLTHROUGH;
13981+
case NEON::BI__builtin_neon_vcvt1_bf16_mf8_fpm:
13982+
case NEON::BI__builtin_neon_vcvt1_high_bf16_mf8_fpm:
13983+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl1,
13984+
llvm::FixedVectorType::get(BFloatTy, 8),
13985+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt1");
13986+
case NEON::BI__builtin_neon_vcvt2_low_bf16_mf8_fpm:
13987+
ExtractLow = true;
13988+
LLVM_FALLTHROUGH;
13989+
case NEON::BI__builtin_neon_vcvt2_bf16_mf8_fpm:
13990+
case NEON::BI__builtin_neon_vcvt2_high_bf16_mf8_fpm:
13991+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl2,
13992+
llvm::FixedVectorType::get(BFloatTy, 8),
13993+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt2");
13994+
case NEON::BI__builtin_neon_vcvt1_low_f16_mf8_fpm:
13995+
ExtractLow = true;
13996+
LLVM_FALLTHROUGH;
13997+
case NEON::BI__builtin_neon_vcvt1_f16_mf8_fpm:
13998+
case NEON::BI__builtin_neon_vcvt1_high_f16_mf8_fpm:
13999+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl1,
14000+
llvm::FixedVectorType::get(HalfTy, 8),
14001+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt1");
14002+
case NEON::BI__builtin_neon_vcvt2_low_f16_mf8_fpm:
14003+
ExtractLow = true;
14004+
LLVM_FALLTHROUGH;
14005+
case NEON::BI__builtin_neon_vcvt2_f16_mf8_fpm:
14006+
case NEON::BI__builtin_neon_vcvt2_high_f16_mf8_fpm:
14007+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_cvtl2,
14008+
llvm::FixedVectorType::get(HalfTy, 8),
14009+
Ops[0]->getType(), ExtractLow, Ops, E, "vbfcvt2");
14010+
case NEON::BI__builtin_neon_vcvt_mf8_f32_fpm:
14011+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn,
14012+
llvm::FixedVectorType::get(Int8Ty, 8),
14013+
Ops[0]->getType(), false, Ops, E, "vfcvtn");
14014+
case NEON::BI__builtin_neon_vcvt_mf8_f16_fpm:
14015+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn,
14016+
llvm::FixedVectorType::get(Int8Ty, 8),
14017+
llvm::FixedVectorType::get(HalfTy, 4), false, Ops,
14018+
E, "vfcvtn");
14019+
case NEON::BI__builtin_neon_vcvtq_mf8_f16_fpm:
14020+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn,
14021+
llvm::FixedVectorType::get(Int8Ty, 16),
14022+
llvm::FixedVectorType::get(HalfTy, 8), false, Ops,
14023+
E, "vfcvtn");
14024+
case NEON::BI__builtin_neon_vcvt_high_mf8_f32_fpm: {
14025+
llvm::Type *Ty = llvm::FixedVectorType::get(Int8Ty, 16);
14026+
Ops[0] = Builder.CreateInsertVector(Ty, PoisonValue::get(Ty), Ops[0],
14027+
Builder.getInt64(0));
14028+
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn2, Ty,
14029+
Ops[1]->getType(), false, Ops, E, "vfcvtn2");
14030+
}
1395414031
case NEON::BI__builtin_neon_vamin_f16:
1395514032
case NEON::BI__builtin_neon_vaminq_f16:
1395614033
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4692,6 +4692,13 @@ class CodeGenFunction : public CodeGenTypeCache {
46924692
SmallVectorImpl<llvm::Value*> &O,
46934693
const char *name,
46944694
unsigned shift = 0, bool rightshift = false);
4695+
/// Emits a call to the aarch64_set_fpmr intrinsic with \p FPM, then emits
/// the NEON call to \p F with operands \p O and returns its result.
llvm::Value *EmitFP8NeonCall(llvm::Function *F,
4696+
SmallVectorImpl<llvm::Value *> &O,
4697+
llvm::Value *FPM, const char *name);
/// Emits the FP8 conversion intrinsic \p IID overloaded on {Ty0, Ty1}. When
/// \p Extract is true, Ops[0] is replaced by its low <8 x i8> subvector.
/// The FPM value is read from the last argument of \p E.
4698+
llvm::Value *EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
4699+
llvm::Type *Ty1, bool Extract,
4700+
SmallVectorImpl<llvm::Value *> &Ops,
4701+
const CallExpr *E, const char *name);
46954702
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx,
46964703
const llvm::ElementCount &Count);
46974704
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx);

0 commit comments

Comments
 (0)