Skip to content

Commit 6cf0457

Browse files
author
Dmitry Sidorov
committed
[Backport to 14] Add __builtin_spirv_ internal builtins (KhronosGroup#3374)
Way they are implemented is described in: KhronosGroup#3221 The PR also adds SPV_EXT_float8 extension and packed conversions for SPV_INTEL_int4 Currently only conversion instructions (and internal builtins) are covered. TODO: in the following PR Saturation decoration will be added. Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com> --------- Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
1 parent d5bdfc0 commit 6cf0457

File tree

14 files changed

+1074
-13
lines changed

14 files changed

+1074
-13
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
7474
EXT(SPV_KHR_bfloat16)
7575
EXT(SPV_INTEL_bfloat16_arithmetic)
7676
EXT(SPV_INTEL_16bit_atomics)
77+
EXT(SPV_INTEL_shader_atomic_bfloat16)
78+
EXT(SPV_EXT_float8)
7779
EXT(SPV_INTEL_predicated_io)
7880
EXT(SPV_INTEL_sigmoid)
7981
EXT(SPV_INTEL_ternary_bitwise_function)

lib/SPIRV/SPIRVInternal.h

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ const static char ConvertHandleToImageINTEL[] = "ConvertHandleToImageINTEL";
394394
const static char ConvertHandleToSamplerINTEL[] = "ConvertHandleToSamplerINTEL";
395395
const static char ConvertHandleToSampledImageINTEL[] =
396396
"ConvertHandleToSampledImageINTEL";
397+
const static char InternalBuiltinPrefix[] = "__builtin_spirv_";
397398
} // namespace kSPIRVName
398399

399400
namespace kSPIRVPostfix {
@@ -722,7 +723,7 @@ Op getSPIRVFuncOC(StringRef Name, SmallVectorImpl<std::string> *Dec = nullptr);
722723
bool getSPIRVBuiltin(const std::string &Name, spv::BuiltIn &Builtin);
723724

724725
/// \param Name LLVM function name
725-
/// \param DemangledName demanged name of the OpenCL built-in function
726+
/// \param DemangledName demangled name of the OpenCL built-in function
726727
/// \returns true if Name is the name of the OpenCL built-in function,
727728
/// false for other functions
728729
bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp = false);
@@ -799,6 +800,9 @@ CallInst *addCallInst(Module *M, StringRef FuncName, Type *RetTy,
799800
StringRef InstName = SPIR_TEMP_NAME_PREFIX_CALL,
800801
bool TakeFuncName = true);
801802

803+
/// Check if an LLVM type is spirv.CooperativeMatrixKHR.
804+
bool isLLVMCooperativeMatrixType(llvm::Type *Ty);
805+
802806
/// Add a call instruction for SPIR-V builtin function.
803807
CallInst *addCallInstSPIRV(Module *M, StringRef FuncName, Type *RetTy,
804808
ArrayRef<Value *> Args, AttributeList *Attrs,
@@ -1148,6 +1152,85 @@ MetadataAsValue *map2MDString(LLVMContext &C, SPIRVValue *V);
11481152
/// The return value is undefined if the input is larger than the largest power
11491153
/// of two representable in SPIRVWord.
11501154
[[nodiscard]] SPIRVWord bitCeil(SPIRVWord Value);
1155+
1156+
/// \param MangledName LLVM function name.
1157+
/// \param DemangledName demangled name of the input function if it is the
1158+
/// translator's internal built-in function.
1159+
/// \returns true if MangledName is the name of the translator's internal
1160+
/// built-in function, false for other functions.
1161+
/// Used for 'mini'-floats conversion functions
1162+
bool isInternalSPIRVBuiltin(StringRef MangledName, std::string &DemangledName);
1163+
1164+
// Wrapper around SPIR-V 1.6.4 FP Encoding to be used in the conversion
1165+
// descriptor
1166+
enum FPEncodingWrap {
1167+
Integer = FPEncoding::FPEncodingMax - 1,
1168+
IEEE754 = FPEncoding::FPEncodingMax,
1169+
BF16 = FPEncoding::FPEncodingBFloat16KHR,
1170+
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
1171+
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
1172+
};
1173+
1174+
// Structure describing non-trivial conversions (FP8 and int4)
1175+
struct FPConversionDesc {
1176+
public:
1177+
FPEncodingWrap SrcEncoding;
1178+
FPEncodingWrap DstEncoding;
1179+
SPIRVWord ConvOpCode;
1180+
1181+
// To use as a key in std::map
1182+
bool operator==(const FPConversionDesc &Other) const {
1183+
return SrcEncoding == Other.SrcEncoding &&
1184+
DstEncoding == Other.DstEncoding && ConvOpCode == Other.ConvOpCode;
1185+
}
1186+
1187+
bool operator<(const FPConversionDesc &Other) const {
1188+
if (ConvOpCode != Other.ConvOpCode)
1189+
return ConvOpCode < Other.ConvOpCode;
1190+
if (SrcEncoding != Other.SrcEncoding)
1191+
return SrcEncoding < Other.SrcEncoding;
1192+
return DstEncoding < Other.DstEncoding;
1193+
}
1194+
};
1195+
1196+
// Maps internal builtin name to conversion descriptor
1197+
typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;
1198+
1199+
// clang-format off
1200+
template <> inline void FPConvertToEncodingMap::init() {
1201+
// 8-bit conversions
1202+
add("ConvertE4M3ToFP16EXT",
1203+
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1204+
add("ConvertE5M2ToFP16EXT",
1205+
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1206+
add("ConvertE4M3ToBF16EXT",
1207+
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1208+
add("ConvertE5M2ToBF16EXT",
1209+
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1210+
add("ConvertFP16ToE4M3EXT",
1211+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1212+
add("ConvertFP16ToE5M2EXT",
1213+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1214+
add("ConvertBF16ToE4M3EXT",
1215+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1216+
add("ConvertBF16ToE5M2EXT",
1217+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1218+
1219+
add("ConvertInt4ToE4M3INTEL",
1220+
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1221+
add("ConvertInt4ToE5M2INTEL",
1222+
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1223+
add("ConvertInt4ToFP16INTEL",
1224+
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1225+
add("ConvertInt4ToBF16INTEL",
1226+
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1227+
add("ConvertFP16ToInt4INTEL",
1228+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1229+
add("ConvertBF16ToInt4INTEL",
1230+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1231+
}
1232+
1233+
// clang-format on
11511234
} // namespace SPIRV
11521235

11531236
#endif // SPIRV_SPIRVINTERNAL_H

lib/SPIRV/SPIRVReader.cpp

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,9 @@ llvm::Optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
309309

310310
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
311311
switch (T->getFloatBitWidth()) {
312+
case 8:
313+
// No LLVM IR counter part for FP8 - map it on i8
314+
return Type::getIntNTy(*Context, 8);
312315
case 16:
313316
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
314317
return Type::getBFloatTy(*Context);
@@ -1051,6 +1054,22 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10511054
CastInst::CastOps CO = Instruction::BitCast;
10521055
bool IsExt =
10531056
Dst->getScalarSizeInBits() > Src->getType()->getScalarSizeInBits();
1057+
1058+
auto GetFPEncoding = [](SPIRVType *Ty) -> FPEncodingWrap {
1059+
if (Ty->isTypeFloat()) {
1060+
unsigned Enc =
1061+
static_cast<SPIRVTypeFloat *>(Ty)->getFloatingPointEncoding();
1062+
return static_cast<FPEncodingWrap>(Enc);
1063+
}
1064+
if (Ty->isTypeInt())
1065+
return FPEncodingWrap::Integer;
1066+
return FPEncodingWrap::IEEE754;
1067+
};
1068+
1069+
auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1070+
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
1071+
};
1072+
10541073
switch (BC->getOpCode()) {
10551074
case OpPtrCastToGeneric:
10561075
case OpGenericCastToPtr:
@@ -1072,10 +1091,84 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10721091
case OpUConvert:
10731092
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
10741093
break;
1075-
case OpFConvert:
1076-
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
1094+
case OpConvertSToF:
1095+
case OpConvertFToS:
1096+
case OpConvertUToF:
1097+
case OpConvertFToU:
1098+
case OpFConvert: {
1099+
const auto OC = BC->getOpCode();
1100+
{
1101+
auto SPVOps = BC->getOperands();
1102+
auto *SPVSrcTy = SPVOps[0]->getType();
1103+
auto *SPVDstTy = BC->getType();
1104+
1105+
auto GetEncodingAndUpdateType =
1106+
[GetFPEncoding](SPIRVType *&SPVTy) -> FPEncodingWrap {
1107+
if (SPVTy->isTypeVector()) {
1108+
SPVTy = SPVTy->getVectorComponentType();
1109+
} else if (SPVTy->isTypeCooperativeMatrixKHR()) {
1110+
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(SPVTy);
1111+
SPVTy = MT->getCompType();
1112+
}
1113+
return GetFPEncoding(SPVTy);
1114+
};
1115+
1116+
FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
1117+
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
1118+
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
1119+
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
1120+
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
1121+
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
1122+
std::vector<Value *> Ops = {Src};
1123+
std::vector<Type *> OpsTys = {Src->getType()};
1124+
1125+
std::string BuiltinName =
1126+
kSPIRVName::InternalBuiltinPrefix + std::string(Conv);
1127+
BuiltinFuncMangleInfo Info;
1128+
std::string MangledName;
1129+
// Translate additional Ops for stochastic conversions.
1130+
if (OC == internal::OpStochasticRoundFToFINTEL ||
1131+
OC == internal::OpClampStochasticRoundFToFINTEL ||
1132+
OC == internal::OpClampStochasticRoundFToSINTEL) {
1133+
// Seed.
1134+
Ops.emplace_back(transValue(SPVOps[1], F, BB, true));
1135+
OpsTys.emplace_back(Ops[1]->getType());
1136+
constexpr unsigned MaxOpsSize = 3;
1137+
if (SPVOps.size() == MaxOpsSize) {
1138+
// New Seed.
1139+
Ops.emplace_back(transValue(SPVOps[2], F, BB, true));
1140+
1141+
// The following mess is needed to create a function with correct
1142+
// mangling.
1143+
SPIRVType *PtrTy = SPVOps[2]->getType();
1144+
const unsigned AS =
1145+
SPIRSPIRVAddrSpaceMap::rmap(PtrTy->getPointerStorageClass());
1146+
Type *ElementTy = transType(PtrTy->getPointerElementType());
1147+
// LLVM 15 uses typed pointers natively
1148+
OpsTys.emplace_back(PointerType::get(ElementTy, AS));
1149+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1150+
}
1151+
}
1152+
1153+
if (MangledName.empty())
1154+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1155+
1156+
FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
1157+
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
1158+
return CallInst::Create(Func, Ops, "", BB);
1159+
}
1160+
}
1161+
1162+
if (OC == OpFConvert) {
1163+
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
1164+
break;
1165+
}
1166+
CO = static_cast<CastInst::CastOps>(OpCodeMap::rmap(OC));
10771167
break;
1168+
}
10781169
case OpBitcast:
1170+
if (!Dst->isPointerTy() && Dst == Src->getType())
1171+
return Src;
10791172
// OpBitcast need to be handled as a special-case when the source is a
10801173
// pointer and the destination is not a pointer, and where the source is not
10811174
// a pointer and the destination is a pointer. This is supported by the
@@ -2859,11 +2952,29 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
28592952
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
28602953
auto BI = static_cast<SPIRVInstruction *>(BV);
28612954
Value *Inst = nullptr;
2862-
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion() ||
2863-
BI->getType()->isTypeCooperativeMatrixKHR())
2955+
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion()) {
28642956
Inst = transSPIRVBuiltinFromInst(BI, BB);
2865-
else
2957+
} else if (BI->getType()->isTypeCooperativeMatrixKHR()) {
2958+
// For cooperative matrix conversions generate __builtin_spirv
2959+
// conversions instead of __spirv_FConvert in case of mini-float
2960+
// type element type.
2961+
auto *OutMatrixElementTy =
2962+
static_cast<SPIRVTypeCooperativeMatrixKHR *>(BI->getType())
2963+
->getCompType();
2964+
auto *InMatrixElementTy =
2965+
static_cast<SPIRVTypeCooperativeMatrixKHR *>(
2966+
static_cast<SPIRVUnary *>(BI)->getOperand(0)->getType())
2967+
->getCompType();
2968+
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
2969+
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
2970+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
2971+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
2972+
Inst = transConvertInst(BV, F, BB);
2973+
else
2974+
Inst = transSPIRVBuiltinFromInst(BI, BB);
2975+
} else {
28662976
Inst = transConvertInst(BV, F, BB);
2977+
}
28672978
return mapValue(BV, Inst);
28682979
}
28692980
return mapValue(

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@
3838
///
3939
//===----------------------------------------------------------------------===//
4040

41+
// This file needs to be included before anything that declares
42+
// llvm::PointerType to avoid a compilation bug on MSVC.
43+
#include "llvm/Demangle/Demangle.h"
44+
#include "llvm/Demangle/ItaniumDemangle.h"
45+
4146
#include "FunctionDescriptor.h"
4247
#include "ManglingUtils.h"
4348
#include "NameMangleAPI.h"
@@ -50,7 +55,6 @@
5055

5156
#include "llvm/ADT/StringSwitch.h"
5257
#include "llvm/Bitcode/BitcodeWriter.h"
53-
#include "llvm/Demangle/Demangle.h"
5458
#include "llvm/IR/IRBuilder.h"
5559
#include "llvm/IR/IntrinsicInst.h"
5660
#include "llvm/IR/Metadata.h"
@@ -340,6 +344,22 @@ bool isSYCLBfloat16Type(llvm::Type *Ty) {
340344
return false;
341345
}
342346

347+
bool isLLVMCooperativeMatrixType(llvm::Type *Ty) {
348+
// In LLVM 14/15, TargetExtType doesn't exist, so cooperative matrices are
349+
// represented as struct types with names like "spirv.CooperativeMatrixKHR._..."
350+
auto *ST = dyn_cast<StructType>(Ty);
351+
if (!ST || !ST->hasName())
352+
return false;
353+
354+
StringRef STName = ST->getName();
355+
if (!STName.startswith(kSPIRVTypeName::PrefixAndDelim))
356+
return false;
357+
358+
SmallVector<std::string, 8> Postfixes;
359+
std::string TN = decodeSPIRVTypeName(STName, Postfixes);
360+
return TN == kSPIRVTypeName::CooperativeMatrixKHR;
361+
}
362+
343363
Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
344364
StringRef Name, BuiltinFuncMangleInfo *Mangle,
345365
AttributeList *Attrs, bool TakeName) {
@@ -537,7 +557,7 @@ bool getSPIRVBuiltin(const std::string &OrigName, spv::BuiltIn &B) {
537557
return getByName(R.str(), B);
538558
}
539559

540-
// Demangled name is a substring of the name. The DemangledName is updated only
560+
// DemangledName is a substring of Name. The DemangledName is updated only
541561
// if true is returned
542562
bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
543563
if (Name == "printf") {
@@ -576,6 +596,65 @@ bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
576596
return true;
577597
}
578598

599+
// DemangledName is a substring of Name. The DemangledName is updated only
600+
// if true is returned.
601+
bool isInternalSPIRVBuiltin(StringRef Name, std::string &DemangledName) {
602+
if (!Name.startswith("_Z"))
603+
return false;
604+
constexpr unsigned DemangledNameLenStart = 2;
605+
size_t Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
606+
if (!Name.substr(Start).startswith(kSPIRVName::InternalBuiltinPrefix))
607+
return false;
608+
// LLVM 16 doesn't have itaniumDemangle(const char*, bool) overload
609+
// Try using the C-style API first, but it may fail for types like bfloat
610+
// (DF16b) that LLVM 16's demangler doesn't recognize
611+
char *Demangled =
612+
llvm::itaniumDemangle(Name.data(), nullptr, nullptr, nullptr);
613+
614+
if (Demangled) {
615+
// Demangling succeeded - extract function name from demangled string
616+
std::string FullDemangled(Demangled);
617+
std::free(Demangled);
618+
if (!StringRef(FullDemangled).startswith(kSPIRVName::InternalBuiltinPrefix))
619+
return false;
620+
// Remove the __builtin_spirv_ prefix
621+
std::string WithoutPrefix =
622+
FullDemangled.substr(strlen(kSPIRVName::InternalBuiltinPrefix));
623+
// Extract just the function name (before the opening parenthesis)
624+
size_t ParenPos = WithoutPrefix.find('(');
625+
if (ParenPos != std::string::npos)
626+
DemangledName = WithoutPrefix.substr(0, ParenPos);
627+
else
628+
DemangledName = WithoutPrefix;
629+
} else {
630+
// Demangling failed (e.g., for bfloat types in LLVM 16)
631+
// Fall back to extracting function name directly from the mangled name
632+
// Format: _Z<len>__builtin_spirv_<FunctionName><TypeEncoding>
633+
// All conversion function names end with "EXT" or "INTEL"
634+
StringRef AfterPrefix =
635+
Name.substr(Start + strlen(kSPIRVName::InternalBuiltinPrefix));
636+
637+
size_t ExtPos = AfterPrefix.find("EXT");
638+
size_t IntelPos = AfterPrefix.find("INTEL");
639+
size_t FuncNameEnd = StringRef::npos;
640+
641+
if (ExtPos != StringRef::npos && IntelPos != StringRef::npos)
642+
FuncNameEnd =
643+
std::min(ExtPos + 3, IntelPos + 5); // +3 for "EXT", +5 for "INTEL"
644+
else if (ExtPos != StringRef::npos)
645+
FuncNameEnd = ExtPos + 3;
646+
else if (IntelPos != StringRef::npos)
647+
FuncNameEnd = IntelPos + 5;
648+
649+
if (FuncNameEnd == StringRef::npos)
650+
DemangledName = AfterPrefix.str();
651+
else
652+
DemangledName = AfterPrefix.substr(0, FuncNameEnd).str();
653+
}
654+
655+
return true;
656+
}
657+
579658
// Check if a mangled type Name is unsigned
580659
bool isMangledTypeUnsigned(char Mangled) {
581660
return Mangled == 'h' /* uchar */

0 commit comments

Comments
 (0)