Skip to content

Commit 066de4a

Browse files
author
Dmitry Sidorov
committed
[Backport to 17] Implement SPV_INTEL_float4 and SPV_INTEL_fp_conversions extensions (KhronosGroup#3419)
Also implements their corresponding conversions via the __builtin_spirv mechanism. Specification: intel/llvm#20467 Signed-off-by: Dmitry Sidorov <dmitry.sidorov@intel.com>
1 parent f49317c commit 066de4a

File tree

15 files changed

+1125
-105
lines changed

15 files changed

+1125
-105
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,5 @@ EXT(SPV_INTEL_shader_atomic_bfloat16)
8484
EXT(SPV_EXT_float8)
8585
EXT(SPV_INTEL_predicated_io)
8686
EXT(SPV_INTEL_sigmoid)
87+
EXT(SPV_INTEL_float4)
88+
EXT(SPV_INTEL_fp_conversions)

lib/SPIRV/SPIRVInternal.h

Lines changed: 112 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,7 @@ enum FPEncodingWrap {
10631063
BF16 = FPEncoding::FPEncodingBFloat16KHR,
10641064
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
10651065
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
1066+
E2M1 = internal::FPEncodingFloat4E2M1INTEL,
10661067
};
10671068

10681069
// Structure describing non-trivial conversions (FP4, FP8 and int4)
@@ -1091,36 +1092,117 @@ typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;
10911092

10921093
// clang-format off
10931094
template <> inline void FPConvertToEncodingMap::init() {
1094-
// 8-bit conversions
1095-
add("ConvertE4M3ToFP16EXT",
1096-
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1097-
add("ConvertE5M2ToFP16EXT",
1098-
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1099-
add("ConvertE4M3ToBF16EXT",
1100-
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1101-
add("ConvertE5M2ToBF16EXT",
1102-
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1103-
add("ConvertFP16ToE4M3EXT",
1104-
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1105-
add("ConvertFP16ToE5M2EXT",
1106-
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1107-
add("ConvertBF16ToE4M3EXT",
1108-
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1109-
add("ConvertBF16ToE5M2EXT",
1110-
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1111-
1112-
add("ConvertInt4ToE4M3INTEL",
1113-
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1114-
add("ConvertInt4ToE5M2INTEL",
1115-
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1116-
add("ConvertInt4ToFP16INTEL",
1117-
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1118-
add("ConvertInt4ToBF16INTEL",
1119-
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1120-
add("ConvertFP16ToInt4INTEL",
1121-
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1122-
add("ConvertBF16ToInt4INTEL",
1123-
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1095+
// 4-bit conversions
1096+
add("ConvertE2M1ToE4M3INTEL",
1097+
{FPEncodingWrap::E2M1, FPEncodingWrap::E4M3, OpFConvert});
1098+
add("ConvertE2M1ToE5M2INTEL",
1099+
{FPEncodingWrap::E2M1, FPEncodingWrap::E5M2, OpFConvert});
1100+
add("ConvertE2M1ToFP16INTEL",
1101+
{FPEncodingWrap::E2M1, FPEncodingWrap::IEEE754, OpFConvert});
1102+
add("ConvertE2M1ToBF16INTEL",
1103+
{FPEncodingWrap::E2M1, FPEncodingWrap::BF16, OpFConvert});
1104+
1105+
add("ConvertInt4ToE4M3INTEL",
1106+
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1107+
add("ConvertInt4ToE5M2INTEL",
1108+
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1109+
add("ConvertInt4ToFP16INTEL",
1110+
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1111+
add("ConvertInt4ToBF16INTEL",
1112+
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1113+
add("ConvertInt4ToInt8INTEL",
1114+
{FPEncodingWrap::Integer, FPEncodingWrap::Integer, OpSConvert});
1115+
1116+
add("ConvertFP16ToE2M1INTEL",
1117+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1, OpFConvert});
1118+
add("ConvertBF16ToE2M1INTEL",
1119+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1, OpFConvert});
1120+
add("ConvertFP16ToInt4INTEL",
1121+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1122+
add("ConvertBF16ToInt4INTEL",
1123+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1124+
1125+
// 8-bit conversions
1126+
add("ConvertE4M3ToFP16EXT",
1127+
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1128+
add("ConvertE5M2ToFP16EXT",
1129+
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1130+
add("ConvertE4M3ToBF16EXT",
1131+
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1132+
add("ConvertE5M2ToBF16EXT",
1133+
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1134+
add("ConvertFP16ToE4M3EXT",
1135+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1136+
add("ConvertFP16ToE5M2EXT",
1137+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1138+
add("ConvertBF16ToE4M3EXT",
1139+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1140+
add("ConvertBF16ToE5M2EXT",
1141+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1142+
1143+
// SPV_INTEL_fp_conversions
1144+
add("ClampConvertFP16ToE2M1INTEL",
1145+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
1146+
internal::OpClampConvertFToFINTEL});
1147+
add("ClampConvertBF16ToE2M1INTEL",
1148+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
1149+
internal::OpClampConvertFToFINTEL});
1150+
add("ClampConvertFP16ToE4M3INTEL",
1151+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1152+
internal::OpClampConvertFToFINTEL});
1153+
add("ClampConvertBF16ToE4M3INTEL",
1154+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1155+
internal::OpClampConvertFToFINTEL});
1156+
add("ClampConvertFP16ToE5M2INTEL",
1157+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1158+
internal::OpClampConvertFToFINTEL});
1159+
add("ClampConvertBF16ToE5M2INTEL",
1160+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1161+
internal::OpClampConvertFToFINTEL});
1162+
add("ClampConvertFP16ToInt4INTEL",
1163+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
1164+
internal::OpClampConvertFToSINTEL});
1165+
add("ClampConvertBF16ToInt4INTEL",
1166+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
1167+
internal::OpClampConvertFToSINTEL});
1168+
1169+
add("StochasticRoundFP16ToE5M2INTEL",
1170+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1171+
internal::OpStochasticRoundFToFINTEL});
1172+
add("StochasticRoundFP16ToE4M3INTEL",
1173+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1174+
internal::OpStochasticRoundFToFINTEL});
1175+
add("StochasticRoundBF16ToE5M2INTEL",
1176+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1177+
internal::OpStochasticRoundFToFINTEL});
1178+
add("StochasticRoundBF16ToE4M3INTEL",
1179+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1180+
internal::OpStochasticRoundFToFINTEL});
1181+
add("StochasticRoundFP16ToE2M1INTEL",
1182+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
1183+
internal::OpStochasticRoundFToFINTEL});
1184+
add("StochasticRoundBF16ToE2M1INTEL",
1185+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
1186+
internal::OpStochasticRoundFToFINTEL});
1187+
add("ClampStochasticRoundFP16ToInt4INTEL",
1188+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
1189+
internal::OpClampStochasticRoundFToSINTEL});
1190+
add("ClampStochasticRoundBF16ToInt4INTEL",
1191+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
1192+
internal::OpClampStochasticRoundFToSINTEL});
1193+
1194+
add("ClampStochasticRoundFP16ToE5M2INTEL",
1195+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1196+
internal::OpClampStochasticRoundFToFINTEL});
1197+
add("ClampStochasticRoundFP16ToE4M3INTEL",
1198+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1199+
internal::OpClampStochasticRoundFToFINTEL});
1200+
add("ClampStochasticRoundBF16ToE5M2INTEL",
1201+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1202+
internal::OpClampStochasticRoundFToFINTEL});
1203+
add("ClampStochasticRoundBF16ToE4M3INTEL",
1204+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1205+
internal::OpClampStochasticRoundFToFINTEL});
11241206
}
11251207

11261208
// clang-format on

lib/SPIRV/SPIRVReader.cpp

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,11 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
313313

314314
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
315315
switch (T->getFloatBitWidth()) {
316+
case 4:
317+
// No LLVM IR counterpart for FP4 - map it to i4.
318+
return Type::getIntNTy(*Context, 4);
316319
case 8:
317-
// No LLVM IR counter part for FP8 - map it on i8
320+
// No LLVM IR counterpart for FP8 - map it to i8.
318321
return Type::getIntNTy(*Context, 8);
319322
case 16:
320323
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
@@ -1072,11 +1075,12 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10721075
return FPEncodingWrap::IEEE754;
10731076
};
10741077

1075-
auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1076-
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
1078+
auto IsFP4OrFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1079+
return Encoding == FPEncodingWrap::E4M3 ||
1080+
Encoding == FPEncodingWrap::E5M2 || Encoding == FPEncodingWrap::E2M1;
10771081
};
10781082

1079-
switch (BC->getOpCode()) {
1083+
switch (static_cast<unsigned>(BC->getOpCode())) {
10801084
case OpPtrCastToGeneric:
10811085
case OpGenericCastToPtr:
10821086
case OpPtrCastToCrossWorkgroupINTEL:
@@ -1097,6 +1101,11 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10971101
case OpUConvert:
10981102
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
10991103
break;
1104+
case internal::OpClampConvertFToFINTEL:
1105+
case internal::OpClampConvertFToSINTEL:
1106+
case internal::OpStochasticRoundFToFINTEL:
1107+
case internal::OpClampStochasticRoundFToFINTEL:
1108+
case internal::OpClampStochasticRoundFToSINTEL:
11001109
case OpConvertSToF:
11011110
case OpConvertFToS:
11021111
case OpConvertUToF:
@@ -1121,7 +1130,7 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11211130

11221131
FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
11231132
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
1124-
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
1133+
if (IsFP4OrFP8Encoding(SrcEnc) || IsFP4OrFP8Encoding(DstEnc) ||
11251134
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
11261135
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
11271136
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
@@ -1131,13 +1140,47 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11311140
std::string BuiltinName =
11321141
kSPIRVName::InternalBuiltinPrefix + std::string(Conv);
11331142
BuiltinFuncMangleInfo Info;
1134-
std::string MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1143+
std::string MangledName;
1144+
// Translate additional Ops for stochastic conversions.
1145+
if (OC == internal::OpStochasticRoundFToFINTEL ||
1146+
OC == internal::OpClampStochasticRoundFToFINTEL ||
1147+
OC == internal::OpClampStochasticRoundFToSINTEL) {
1148+
// Seed.
1149+
Ops.emplace_back(transValue(SPVOps[1], F, BB, true));
1150+
OpsTys.emplace_back(Ops[1]->getType());
1151+
constexpr unsigned MaxOpsSize = 3;
1152+
if (SPVOps.size() == MaxOpsSize) {
1153+
// New Seed.
1154+
Ops.emplace_back(transValue(SPVOps[2], F, BB, true));
1155+
1156+
// The following mess is needed to create a function with correct
1157+
// mangling.
1158+
SPIRVType *PtrTy = SPVOps[2]->getType();
1159+
const unsigned AS =
1160+
SPIRSPIRVAddrSpaceMap::rmap(PtrTy->getPointerStorageClass());
1161+
Type *ElementTy = transType(PtrTy->getPointerElementType());
1162+
OpsTys.emplace_back(TypedPointerType::get(ElementTy, AS));
1163+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1164+
// But to create the function itself we need an untyped pointer type.
1165+
OpsTys[2] = opaquifyType(OpsTys[2]);
1166+
}
1167+
}
1168+
1169+
if (MangledName.empty())
1170+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
11351171

11361172
FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
11371173
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
11381174
return CallInst::Create(Func, Ops, "", BB);
11391175
}
11401176
}
1177+
// These conversions can be done without __builtin_spirv prefixed functions
1178+
// as their operand and result types have native representation in LLVM IR.
1179+
if (OC == internal::OpClampConvertFToFINTEL ||
1180+
OC == internal::OpStochasticRoundFToFINTEL ||
1181+
OC == internal::OpClampStochasticRoundFToFINTEL)
1182+
return mapValue(BV, transSPIRVBuiltinFromInst(
1183+
static_cast<SPIRVInstruction *>(BV), BB));
11411184

11421185
if (OC == OpFConvert) {
11431186
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
@@ -2974,7 +3017,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
29743017
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
29753018
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
29763019
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
2977-
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
3020+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
3021+
OutMatrixElementTy->isTypeFloat(
3022+
4, internal::FPEncodingFloat4E2M1INTEL) ||
3023+
InMatrixElementTy->isTypeFloat(4,
3024+
internal::FPEncodingFloat4E2M1INTEL))
29783025
Inst = transConvertInst(BV, F, BB);
29793026
else
29803027
Inst = transSPIRVBuiltinFromInst(BI, BB);
@@ -2983,6 +3030,8 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
29833030
}
29843031
return mapValue(BV, Inst);
29853032
}
3033+
if (isIntelCvtOpCode(OC))
3034+
return mapValue(BV, transConvertInst(BV, F, BB));
29863035
return mapValue(
29873036
BV, transSPIRVBuiltinFromInst(static_cast<SPIRVInstruction *>(BV), BB));
29883037
}
@@ -3706,6 +3755,11 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
37063755
case internal::OpCooperativeMatrixLoadCheckedINTEL:
37073756
case internal::OpConvertHandleToImageINTEL:
37083757
case internal::OpConvertHandleToSampledImageINTEL:
3758+
case internal::OpClampConvertFToFINTEL:
3759+
case internal::OpClampConvertFToSINTEL:
3760+
case internal::OpStochasticRoundFToFINTEL:
3761+
case internal::OpClampStochasticRoundFToFINTEL:
3762+
case internal::OpClampStochasticRoundFToSINTEL:
37093763
AddRetTypePostfix = true;
37103764
break;
37113765
default: {

lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ void SPIRVToOCLBase::visitCastInst(CastInst &Cast) {
234234
DstVecTy->getScalarSizeInBits() == 1)
235235
return;
236236

237+
// We don't have OpenCL builtins for 4-bit conversions.
238+
if (DstVecTy->getScalarSizeInBits() == 4 || SrcTy->getScalarSizeInBits() == 4)
239+
return;
240+
237241
// Assemble built-in name -> convert_gentypeN
238242
std::string CastBuiltInName(kOCLBuiltinName::ConvertPrefix);
239243
// Check if this is 'floating point -> unsigned integer' cast

0 commit comments

Comments
 (0)