Skip to content

Commit 019578e

Browse files
Implement NEON FP8 vectors as VectorType
Co-Aurhored-By: Caroline Concatto <[email protected]>
1 parent 229d78d commit 019578e

File tree

13 files changed

+62
-34
lines changed

13 files changed

+62
-34
lines changed

clang/include/clang/AST/Type.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2518,6 +2518,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
25182518
bool isFloat32Type() const;
25192519
bool isDoubleType() const;
25202520
bool isBFloat16Type() const;
2521+
bool isMFloat8Type() const;
25212522
bool isFloat128Type() const;
25222523
bool isIbm128Type() const;
25232524
bool isRealType() const; // C99 6.2.5p17 (real floating + integer)
@@ -8532,6 +8533,10 @@ inline bool Type::isBFloat16Type() const {
85328533
return isSpecificBuiltinType(BuiltinType::BFloat16);
85338534
}
85348535

8536+
inline bool Type::isMFloat8Type() const {
8537+
return isSpecificBuiltinType(BuiltinType::MFloat8);
8538+
}
8539+
85358540
inline bool Type::isFloat128Type() const {
85368541
return isSpecificBuiltinType(BuiltinType::Float128);
85378542
}

clang/include/clang/Basic/AArch64SVEACLETypes.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,6 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
201201
SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
202202

203203
AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
204-
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
205-
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
206204

207205
#undef SVE_VECTOR_TYPE
208206
#undef SVE_VECTOR_TYPE_BFLOAT

clang/include/clang/Basic/TargetBuiltins.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ namespace clang {
200200
Float16,
201201
Float32,
202202
Float64,
203-
BFloat16
203+
BFloat16,
204+
MFloat8
204205
};
205206

206207
NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -222,6 +223,7 @@ namespace clang {
222223
switch (getEltType()) {
223224
case Int8:
224225
case Poly8:
226+
case MFloat8:
225227
return 8;
226228
case Int16:
227229
case Float16:

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3970,6 +3970,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
39703970
return "Float64";
39713971
case BuiltinType::BFloat16:
39723972
return "Bfloat16";
3973+
case BuiltinType::MFloat8:
3974+
return "Mfloat8";
39733975
default:
39743976
llvm_unreachable("Unexpected vector element base type");
39753977
}

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6789,6 +6789,7 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
67896789
switch (TypeFlags.getEltType()) {
67906790
case NeonTypeFlags::Int8:
67916791
case NeonTypeFlags::Poly8:
6792+
case NeonTypeFlags::MFloat8:
67926793
return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
67936794
case NeonTypeFlags::Int16:
67946795
case NeonTypeFlags::Poly16:

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,9 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
650650
// An ext_vector_type of Bool is really a vector of bits.
651651
llvm::Type *IRElemTy = VT->isExtVectorBoolType()
652652
? llvm::Type::getInt1Ty(getLLVMContext())
653-
: ConvertType(VT->getElementType());
653+
: (VT->getElementType()->isMFloat8Type()
654+
? llvm::Type::getInt8Ty(getLLVMContext())
655+
: ConvertType(VT->getElementType()));
654656
ResultType = llvm::FixedVectorType::get(IRElemTy, VT->getNumElements());
655657
break;
656658
}

clang/lib/CodeGen/Targets/AArch64.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,6 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
378378
NSRN = std::min(NSRN + 1, 8u);
379379
else {
380380
switch (BT->getKind()) {
381-
case BuiltinType::MFloat8x8:
382-
case BuiltinType::MFloat8x16:
383-
NSRN = std::min(NSRN + 1, 8u);
384-
break;
385381
case BuiltinType::SveBool:
386382
case BuiltinType::SveCount:
387383
NPRN = std::min(NPRN + 1, 4u);
@@ -623,8 +619,7 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
623619
// but with the difference that any floating-point type is allowed,
624620
// including __fp16.
625621
if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
626-
if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
627-
BT->getKind() == BuiltinType::MFloat8x8)
622+
if (BT->isFloatingPoint())
628623
return true;
629624
} else if (const VectorType *VT = Ty->getAs<VectorType>()) {
630625
if (auto Kind = VT->getVectorKind();

clang/lib/Sema/SemaARM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,8 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
352352
return Context.DoubleTy;
353353
case NeonTypeFlags::BFloat16:
354354
return Context.BFloat16Ty;
355+
case NeonTypeFlags::MFloat8:
356+
return Context.MFloat8Ty;
355357
}
356358
llvm_unreachable("Invalid NeonTypeFlag!");
357359
}

clang/lib/Sema/SemaExpr.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10174,6 +10174,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
1017410174
return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
1017510175
IsCompAssign);
1017610176

10177+
// Any operation with MFloat8 type is only possible with C intrinsics
10178+
if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
10179+
(RHSVecType && RHSVecType->getElementType()->isMFloat8Type()))
10180+
return InvalidOperands(Loc, LHS, RHS);
10181+
1017710182
// AltiVec-style "vector bool op vector bool" combinations are allowed
1017810183
// for some operators but not others.
1017910184
if (!AllowBothBool && LHSVecType &&

clang/lib/Sema/SemaType.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8254,7 +8254,8 @@ static bool isPermittedNeonBaseType(QualType &Ty, VectorKind VecKind, Sema &S) {
82548254
BTy->getKind() == BuiltinType::ULongLong ||
82558255
BTy->getKind() == BuiltinType::Float ||
82568256
BTy->getKind() == BuiltinType::Half ||
8257-
BTy->getKind() == BuiltinType::BFloat16;
8257+
BTy->getKind() == BuiltinType::BFloat16 ||
8258+
BTy->getKind() == BuiltinType::MFloat8;
82588259
}
82598260

82608261
static bool verifyValidIntegerConstantExpr(Sema &S, const ParsedAttr &Attr,

0 commit comments

Comments
 (0)