Skip to content

Commit 87103a0

Browse files
[AArch64] Implement NEON FP8 vectors as VectorType (llvm#123603)
Reimplement Neon FP8 vector types using attribute `neon_vector_type` instead of having them as builtin types. This allows to implement FP8 Neon intrinsics without the need to add special cases for these types when using `__builtin_shufflevector` or bitcast (using C-style cast operator) between vectors, both extensively used in the generated code in `arm_neon.h`.
1 parent 3b5e9ee commit 87103a0

19 files changed

+553
-69
lines changed

clang/include/clang/AST/Type.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2518,6 +2518,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
25182518
bool isFloat32Type() const;
25192519
bool isDoubleType() const;
25202520
bool isBFloat16Type() const;
2521+
bool isMFloat8Type() const;
25212522
bool isFloat128Type() const;
25222523
bool isIbm128Type() const;
25232524
bool isRealType() const; // C99 6.2.5p17 (real floating + integer)
@@ -8537,6 +8538,10 @@ inline bool Type::isBFloat16Type() const {
85378538
return isSpecificBuiltinType(BuiltinType::BFloat16);
85388539
}
85398540

8541+
inline bool Type::isMFloat8Type() const {
8542+
return isSpecificBuiltinType(BuiltinType::MFloat8);
8543+
}
8544+
85408545
inline bool Type::isFloat128Type() const {
85418546
return isSpecificBuiltinType(BuiltinType::Float128);
85428547
}

clang/include/clang/Basic/AArch64SVEACLETypes.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,6 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
201201
SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
202202

203203
AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
204-
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
205-
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
206204

207205
#undef SVE_VECTOR_TYPE
208206
#undef SVE_VECTOR_TYPE_BFLOAT

clang/include/clang/Basic/TargetBuiltins.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ namespace clang {
208208
Float16,
209209
Float32,
210210
Float64,
211-
BFloat16
211+
BFloat16,
212+
MFloat8
212213
};
213214

214215
NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -230,6 +231,7 @@ namespace clang {
230231
switch (getEltType()) {
231232
case Int8:
232233
case Poly8:
234+
case MFloat8:
233235
return 8;
234236
case Int16:
235237
case Float16:

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3919,6 +3919,9 @@ void CXXNameMangler::mangleNeonVectorType(const VectorType *T) {
39193919
case BuiltinType::Float: EltName = "float32_t"; break;
39203920
case BuiltinType::Half: EltName = "float16_t"; break;
39213921
case BuiltinType::BFloat16: EltName = "bfloat16_t"; break;
3922+
case BuiltinType::MFloat8:
3923+
EltName = "mfloat8_t";
3924+
break;
39223925
default:
39233926
llvm_unreachable("unexpected Neon vector element type");
39243927
}
@@ -3972,6 +3975,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
39723975
return "Float64";
39733976
case BuiltinType::BFloat16:
39743977
return "Bfloat16";
3978+
case BuiltinType::MFloat8:
3979+
return "Mfloat8";
39753980
default:
39763981
llvm_unreachable("Unexpected vector element base type");
39773982
}

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6679,6 +6679,7 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
66796679
switch (TypeFlags.getEltType()) {
66806680
case NeonTypeFlags::Int8:
66816681
case NeonTypeFlags::Poly8:
6682+
case NeonTypeFlags::MFloat8:
66826683
return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
66836684
case NeonTypeFlags::Int16:
66846685
case NeonTypeFlags::Poly16:

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2414,8 +2414,15 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
24142414
Vec = Builder.CreateBitCast(Vec, IRVecTy);
24152415
// iN --> <N x i1>.
24162416
}
2417-
Vec = Builder.CreateInsertElement(Vec, Src.getScalarVal(),
2418-
Dst.getVectorIdx(), "vecins");
2417+
llvm::Value *SrcVal = Src.getScalarVal();
2418+
// Allow inserting `<1 x T>` into an `<N x T>`. It can happen with scalar
2419+
// types which are mapped to vector LLVM IR types (e.g. for implementing
2420+
// an ABI).
2421+
if (auto *EltTy = dyn_cast<llvm::FixedVectorType>(SrcVal->getType());
2422+
EltTy && EltTy->getNumElements() == 1)
2423+
SrcVal = Builder.CreateBitCast(SrcVal, EltTy->getElementType());
2424+
Vec = Builder.CreateInsertElement(Vec, SrcVal, Dst.getVectorIdx(),
2425+
"vecins");
24192426
if (IRStoreTy) {
24202427
// <N x i1> --> <iN>.
24212428
Vec = Builder.CreateBitCast(Vec, IRStoreTy);

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,9 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
650650
// An ext_vector_type of Bool is really a vector of bits.
651651
llvm::Type *IRElemTy = VT->isExtVectorBoolType()
652652
? llvm::Type::getInt1Ty(getLLVMContext())
653-
: ConvertType(VT->getElementType());
653+
: (VT->getElementType()->isMFloat8Type()
654+
? llvm::Type::getInt8Ty(getLLVMContext())
655+
: ConvertType(VT->getElementType()));
654656
ResultType = llvm::FixedVectorType::get(IRElemTy, VT->getNumElements());
655657
break;
656658
}

clang/lib/CodeGen/Targets/AArch64.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -383,10 +383,6 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
383383
NSRN = std::min(NSRN + 1, 8u);
384384
else {
385385
switch (BT->getKind()) {
386-
case BuiltinType::MFloat8x8:
387-
case BuiltinType::MFloat8x16:
388-
NSRN = std::min(NSRN + 1, 8u);
389-
break;
390386
case BuiltinType::SveBool:
391387
case BuiltinType::SveCount:
392388
NPRN = std::min(NPRN + 1, 4u);
@@ -629,8 +625,7 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
629625
// but with the difference that any floating-point type is allowed,
630626
// including __fp16.
631627
if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
632-
if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
633-
BT->getKind() == BuiltinType::MFloat8x8)
628+
if (BT->isFloatingPoint())
634629
return true;
635630
} else if (const VectorType *VT = Ty->getAs<VectorType>()) {
636631
if (auto Kind = VT->getVectorKind();

clang/lib/Sema/SemaARM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,8 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
352352
return Context.DoubleTy;
353353
case NeonTypeFlags::BFloat16:
354354
return Context.BFloat16Ty;
355+
case NeonTypeFlags::MFloat8:
356+
return Context.MFloat8Ty;
355357
}
356358
llvm_unreachable("Invalid NeonTypeFlag!");
357359
}

clang/lib/Sema/SemaExpr.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7503,7 +7503,7 @@ static bool breakDownVectorType(QualType type, uint64_t &len,
75037503
if (const VectorType *vecType = type->getAs<VectorType>()) {
75047504
len = vecType->getNumElements();
75057505
eltType = vecType->getElementType();
7506-
assert(eltType->isScalarType());
7506+
assert(eltType->isScalarType() || eltType->isMFloat8Type());
75077507
return true;
75087508
}
75097509

@@ -10174,6 +10174,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
1017410174
return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
1017510175
IsCompAssign);
1017610176

10177+
// Any operation with MFloat8 type is only possible with C intrinsics
10178+
if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
10179+
(RHSVecType && RHSVecType->getElementType()->isMFloat8Type()))
10180+
return InvalidOperands(Loc, LHS, RHS);
10181+
1017710182
// AltiVec-style "vector bool op vector bool" combinations are allowed
1017810183
// for some operators but not others.
1017910184
if (!AllowBothBool && LHSVecType &&

0 commit comments

Comments
 (0)