Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ EXT(SPV_INTEL_ternary_bitwise_function)
EXT(SPV_INTEL_int4)
EXT(SPV_INTEL_function_variants)
EXT(SPV_INTEL_shader_atomic_bfloat16)
EXT(SPV_EXT_float8)
EXT(SPV_INTEL_predicated_io)
84 changes: 83 additions & 1 deletion lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ const static char ConvertHandleToImageINTEL[] = "ConvertHandleToImageINTEL";
const static char ConvertHandleToSamplerINTEL[] = "ConvertHandleToSamplerINTEL";
const static char ConvertHandleToSampledImageINTEL[] =
"ConvertHandleToSampledImageINTEL";
const static char InternalBuiltinPrefix[] = "__builtin_spirv_";
} // namespace kSPIRVName

namespace kSPIRVPostfix {
Expand Down Expand Up @@ -665,7 +666,7 @@ Op getSPIRVFuncOC(StringRef Name, SmallVectorImpl<std::string> *Dec = nullptr);
bool getSPIRVBuiltin(const std::string &Name, spv::BuiltIn &Builtin);

/// \param Name LLVM function name
/// \param DemangledName demanged name of the OpenCL built-in function
/// \param DemangledName demangled name of the OpenCL built-in function
/// \returns true if Name is the name of the OpenCL built-in function,
/// false for other functions
bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp = false);
Expand Down Expand Up @@ -728,6 +729,9 @@ CallInst *addCallInst(Module *M, StringRef FuncName, Type *RetTy,
StringRef InstName = SPIR_TEMP_NAME_PREFIX_CALL,
bool TakeFuncName = true);

/// Check if an LLVM type is spirv.CooperativeMatrixKHR.
bool isLLVMCooperativeMatrixType(llvm::Type *Ty);

/// Add a call instruction for SPIR-V builtin function.
CallInst *addCallInstSPIRV(Module *M, StringRef FuncName, Type *RetTy,
ArrayRef<Value *> Args, AttributeList *Attrs,
Expand Down Expand Up @@ -1029,6 +1033,84 @@ bool postProcessBuiltinsReturningStruct(Module *M, bool IsCpp = false);

bool postProcessBuiltinsWithArrayArguments(Module *M, bool IsCpp = false);

/// \param MangledName LLVM function name.
/// \param DemangledName demangled name of the input function if it is the
/// translator's internal built-in function.
/// \returns true if MangledName is the name of the translator's internal
/// built-in function, false for other functions.
/// Used for 'mini'-floats conversion functions
bool isInternalSPIRVBuiltin(StringRef MangledName, StringRef &DemangledName);

// Wrapper around SPIR-V 1.6.4 FP Encoding to be used in the conversion
// descriptor
enum FPEncodingWrap {
Integer = FPEncoding::FPEncodingMax - 1,
IEEE754 = FPEncoding::FPEncodingMax,
BF16 = FPEncoding::FPEncodingBFloat16KHR,
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
};

// Structure describing non-trivial conversions (FP8 and int4)
struct FPConversionDesc {
FPEncodingWrap SrcEncoding;
FPEncodingWrap DstEncoding;
SPIRVWord ConvOpCode;

// To use as a key in std::map
bool operator==(const FPConversionDesc &Other) const {
return SrcEncoding == Other.SrcEncoding &&
DstEncoding == Other.DstEncoding && ConvOpCode == Other.ConvOpCode;
}

bool operator<(const FPConversionDesc &Other) const {
if (ConvOpCode != Other.ConvOpCode)
return ConvOpCode < Other.ConvOpCode;
if (SrcEncoding != Other.SrcEncoding)
return SrcEncoding < Other.SrcEncoding;
return DstEncoding < Other.DstEncoding;
}
};

// Maps internal builtin name to conversion descriptor
typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;

// clang-format off
template <> inline void FPConvertToEncodingMap::init() {
// 8-bit conversions
add("ConvertE4M3ToFP16EXT",
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
add("ConvertE5M2ToFP16EXT",
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
add("ConvertE4M3ToBF16EXT",
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
add("ConvertE5M2ToBF16EXT",
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
add("ConvertFP16ToE4M3EXT",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
add("ConvertFP16ToE5M2EXT",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
add("ConvertBF16ToE4M3EXT",
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
add("ConvertBF16ToE5M2EXT",
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});

add("ConvertInt4ToE4M3INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
add("ConvertInt4ToE5M2INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
add("ConvertInt4ToFP16INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
add("ConvertInt4ToBF16INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
add("ConvertFP16ToInt4INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
add("ConvertBF16ToInt4INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
}

// clang-format on

} // namespace SPIRV

#endif // SPIRV_SPIRVINTERNAL_H
95 changes: 90 additions & 5 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {

Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
switch (T->getFloatBitWidth()) {
case 8:
// No LLVM IR counter part for FP8 - map it on i8
return Type::getIntNTy(*Context, 8);
case 16:
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
return Type::getBFloatTy(*Context);
Expand Down Expand Up @@ -1049,6 +1052,22 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
CastInst::CastOps CO = Instruction::BitCast;
bool IsExt =
Dst->getScalarSizeInBits() > Src->getType()->getScalarSizeInBits();

auto GetFPEncoding = [](SPIRVType *Ty) -> FPEncodingWrap {
if (Ty->isTypeFloat()) {
unsigned Enc =
static_cast<SPIRVTypeFloat *>(Ty)->getFloatingPointEncoding();
return static_cast<FPEncodingWrap>(Enc);
}
if (Ty->isTypeInt())
return FPEncodingWrap::Integer;
return FPEncodingWrap::IEEE754;
};

auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
};

switch (BC->getOpCode()) {
case OpPtrCastToGeneric:
case OpGenericCastToPtr:
Expand All @@ -1070,10 +1089,58 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
case OpUConvert:
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
break;
case OpFConvert:
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
case OpConvertSToF:
case OpConvertFToS:
case OpConvertUToF:
case OpConvertFToU:
case OpFConvert: {
const auto OC = BC->getOpCode();
{
auto SPVOps = BC->getOperands();
auto *SPVSrcTy = SPVOps[0]->getType();
auto *SPVDstTy = BC->getType();

auto GetEncodingAndUpdateType =
[GetFPEncoding](SPIRVType *&SPVTy) -> FPEncodingWrap {
if (SPVTy->isTypeVector()) {
SPVTy = SPVTy->getVectorComponentType();
} else if (SPVTy->isTypeCooperativeMatrixKHR()) {
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(SPVTy);
SPVTy = MT->getCompType();
}
return GetFPEncoding(SPVTy);
};

FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
std::vector<Value *> Ops = {Src};
std::vector<Type *> OpsTys = {Src->getType()};

std::string BuiltinName =
kSPIRVName::InternalBuiltinPrefix + std::string(Conv);
BuiltinFuncMangleInfo Info;
std::string MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);

FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
return CallInst::Create(Func, Ops, "", BB);
}
}

if (OC == OpFConvert) {
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
break;
}
CO = static_cast<CastInst::CastOps>(OpCodeMap::rmap(OC));
break;
}
case OpBitcast:
if (!Dst->isPointerTy() && Dst == Src->getType())
return Src;
// OpBitcast need to be handled as a special-case when the source is a
// pointer and the destination is not a pointer, and where the source is not
// a pointer and the destination is a pointer. This is supported by the
Expand Down Expand Up @@ -2970,11 +3037,29 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
auto *BI = static_cast<SPIRVInstruction *>(BV);
Value *Inst = nullptr;
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion() ||
BI->getType()->isTypeCooperativeMatrixKHR())
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion()) {
Inst = transSPIRVBuiltinFromInst(BI, BB);
else
} else if (BI->getType()->isTypeCooperativeMatrixKHR()) {
// For cooperative matrix conversions generate __builtin_spirv
// conversions instead of __spirv_FConvert in case of mini-float
// type element type.
auto *OutMatrixElementTy =
static_cast<SPIRVTypeCooperativeMatrixKHR *>(BI->getType())
->getCompType();
auto *InMatrixElementTy =
static_cast<SPIRVTypeCooperativeMatrixKHR *>(
static_cast<SPIRVUnary *>(BI)->getOperand(0)->getType())
->getCompType();
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
Inst = transConvertInst(BV, F, BB);
else
Inst = transSPIRVBuiltinFromInst(BI, BB);
} else {
Inst = transConvertInst(BV, F, BB);
}
return mapValue(BV, Inst);
}
return mapValue(
Expand Down
24 changes: 23 additions & 1 deletion lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

// This file needs to be included before anything that declares
// llvm::PointerType to avoid a compilation bug on MSVC.
#include "llvm/Demangle/Demangle.h"
#include "llvm/Demangle/ItaniumDemangle.h"

#include "FunctionDescriptor.h"
Expand Down Expand Up @@ -267,6 +268,12 @@ bool isSYCLBfloat16Type(llvm::Type *Ty) {
return false;
}

bool isLLVMCooperativeMatrixType(llvm::Type *Ty) {
if (auto *TargetTy = dyn_cast<TargetExtType>(Ty))
return TargetTy->getName() == "spirv.CooperativeMatrixKHR";
return false;
}

Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
StringRef Name, BuiltinFuncMangleInfo *Mangle,
AttributeList *Attrs, bool TakeName) {
Expand Down Expand Up @@ -439,7 +446,7 @@ bool getSPIRVBuiltin(const std::string &OrigName, spv::BuiltIn &B) {
return getByName(R.str(), B);
}

// Demangled name is a substring of the name. The DemangledName is updated only
// DemangledName is a substring of Name. The DemangledName is updated only
// if true is returned
bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
if (Name == "printf") {
Expand Down Expand Up @@ -484,6 +491,21 @@ bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
return false;
}

// DemangledName is a substring of Name. The DemangledName is updated only
// if true is returned.
bool isInternalSPIRVBuiltin(StringRef Name, StringRef &DemangledName) {
if (!Name.starts_with("_Z"))
return false;
constexpr unsigned DemangledNameLenStart = 2;
size_t Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
if (!Name.substr(Start, Name.size() - 1)
.starts_with(kSPIRVName::InternalBuiltinPrefix))
return false;
DemangledName = llvm::itaniumDemangle(Name.data(), false);
DemangledName.consume_front(kSPIRVName::InternalBuiltinPrefix);
return true;
}

// Check if a mangled type Name is unsigned
bool isMangledTypeUnsigned(char Mangled) {
return Mangled == 'h' /* uchar */
Expand Down
Loading