Skip to content

Commit 29d77bf

Browse files
author
Dmitry Sidorov
committed
[Backport to 21] Implement SPV_INTEL_float4 and SPV_INTEL_fp_conversions extensions (KhronosGroup#3419)
This also implements the corresponding conversions via the __builtin_spirv mechanism. Specification: intel/llvm#20467 Signed-off-by: Dmitry Sidorov <dmitry.sidorov@intel.com>
1 parent 8bc30a9 commit 29d77bf

File tree

15 files changed

+1127
-106
lines changed

15 files changed

+1127
-106
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,5 @@ EXT(SPV_INTEL_shader_atomic_bfloat16)
8585
EXT(SPV_EXT_float8)
8686
EXT(SPV_INTEL_predicated_io)
8787
EXT(SPV_INTEL_sigmoid)
88+
EXT(SPV_INTEL_float4)
89+
EXT(SPV_INTEL_fp_conversions)

lib/SPIRV/SPIRVInternal.h

Lines changed: 112 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,7 @@ enum FPEncodingWrap {
10481048
BF16 = FPEncoding::FPEncodingBFloat16KHR,
10491049
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
10501050
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
1051+
E2M1 = internal::FPEncodingFloat4E2M1INTEL,
10511052
};
10521053

10531054
// Structure describing non-trivial conversions (FP8 and int4)
@@ -1076,36 +1077,117 @@ typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;
10761077

10771078
// clang-format off
10781079
template <> inline void FPConvertToEncodingMap::init() {
1079-
// 8-bit conversions
1080-
add("ConvertE4M3ToFP16EXT",
1081-
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1082-
add("ConvertE5M2ToFP16EXT",
1083-
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1084-
add("ConvertE4M3ToBF16EXT",
1085-
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1086-
add("ConvertE5M2ToBF16EXT",
1087-
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1088-
add("ConvertFP16ToE4M3EXT",
1089-
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1090-
add("ConvertFP16ToE5M2EXT",
1091-
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1092-
add("ConvertBF16ToE4M3EXT",
1093-
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1094-
add("ConvertBF16ToE5M2EXT",
1095-
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1096-
1097-
add("ConvertInt4ToE4M3INTEL",
1098-
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1099-
add("ConvertInt4ToE5M2INTEL",
1100-
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1101-
add("ConvertInt4ToFP16INTEL",
1102-
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1103-
add("ConvertInt4ToBF16INTEL",
1104-
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1105-
add("ConvertFP16ToInt4INTEL",
1106-
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1107-
add("ConvertBF16ToInt4INTEL",
1108-
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1080+
// 4-bit conversions
1081+
add("ConvertE2M1ToE4M3INTEL",
1082+
{FPEncodingWrap::E2M1, FPEncodingWrap::E4M3, OpFConvert});
1083+
add("ConvertE2M1ToE5M2INTEL",
1084+
{FPEncodingWrap::E2M1, FPEncodingWrap::E5M2, OpFConvert});
1085+
add("ConvertE2M1ToFP16INTEL",
1086+
{FPEncodingWrap::E2M1, FPEncodingWrap::IEEE754, OpFConvert});
1087+
add("ConvertE2M1ToBF16INTEL",
1088+
{FPEncodingWrap::E2M1, FPEncodingWrap::BF16, OpFConvert});
1089+
1090+
add("ConvertInt4ToE4M3INTEL",
1091+
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1092+
add("ConvertInt4ToE5M2INTEL",
1093+
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1094+
add("ConvertInt4ToFP16INTEL",
1095+
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1096+
add("ConvertInt4ToBF16INTEL",
1097+
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1098+
add("ConvertInt4ToInt8INTEL",
1099+
{FPEncodingWrap::Integer, FPEncodingWrap::Integer, OpSConvert});
1100+
1101+
add("ConvertFP16ToE2M1INTEL",
1102+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1, OpFConvert});
1103+
add("ConvertBF16ToE2M1INTEL",
1104+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1, OpFConvert});
1105+
add("ConvertFP16ToInt4INTEL",
1106+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1107+
add("ConvertBF16ToInt4INTEL",
1108+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1109+
1110+
// 8-bit conversions
1111+
add("ConvertE4M3ToFP16EXT",
1112+
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1113+
add("ConvertE5M2ToFP16EXT",
1114+
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1115+
add("ConvertE4M3ToBF16EXT",
1116+
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1117+
add("ConvertE5M2ToBF16EXT",
1118+
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1119+
add("ConvertFP16ToE4M3EXT",
1120+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1121+
add("ConvertFP16ToE5M2EXT",
1122+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1123+
add("ConvertBF16ToE4M3EXT",
1124+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1125+
add("ConvertBF16ToE5M2EXT",
1126+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1127+
1128+
// SPV_INTEL_fp_conversions
1129+
add("ClampConvertFP16ToE2M1INTEL",
1130+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
1131+
internal::OpClampConvertFToFINTEL});
1132+
add("ClampConvertBF16ToE2M1INTEL",
1133+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
1134+
internal::OpClampConvertFToFINTEL});
1135+
add("ClampConvertFP16ToE4M3INTEL",
1136+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1137+
internal::OpClampConvertFToFINTEL});
1138+
add("ClampConvertBF16ToE4M3INTEL",
1139+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1140+
internal::OpClampConvertFToFINTEL});
1141+
add("ClampConvertFP16ToE5M2INTEL",
1142+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1143+
internal::OpClampConvertFToFINTEL});
1144+
add("ClampConvertBF16ToE5M2INTEL",
1145+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1146+
internal::OpClampConvertFToFINTEL});
1147+
add("ClampConvertFP16ToInt4INTEL",
1148+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
1149+
internal::OpClampConvertFToSINTEL});
1150+
add("ClampConvertBF16ToInt4INTEL",
1151+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
1152+
internal::OpClampConvertFToSINTEL});
1153+
1154+
add("StochasticRoundFP16ToE5M2INTEL",
1155+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1156+
internal::OpStochasticRoundFToFINTEL});
1157+
add("StochasticRoundFP16ToE4M3INTEL",
1158+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1159+
internal::OpStochasticRoundFToFINTEL});
1160+
add("StochasticRoundBF16ToE5M2INTEL",
1161+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1162+
internal::OpStochasticRoundFToFINTEL});
1163+
add("StochasticRoundBF16ToE4M3INTEL",
1164+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1165+
internal::OpStochasticRoundFToFINTEL});
1166+
add("StochasticRoundFP16ToE2M1INTEL",
1167+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
1168+
internal::OpStochasticRoundFToFINTEL});
1169+
add("StochasticRoundBF16ToE2M1INTEL",
1170+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
1171+
internal::OpStochasticRoundFToFINTEL});
1172+
add("ClampStochasticRoundFP16ToInt4INTEL",
1173+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
1174+
internal::OpClampStochasticRoundFToSINTEL});
1175+
add("ClampStochasticRoundBF16ToInt4INTEL",
1176+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
1177+
internal::OpClampStochasticRoundFToSINTEL});
1178+
1179+
add("ClampStochasticRoundFP16ToE5M2INTEL",
1180+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1181+
internal::OpClampStochasticRoundFToFINTEL});
1182+
add("ClampStochasticRoundFP16ToE4M3INTEL",
1183+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1184+
internal::OpClampStochasticRoundFToFINTEL});
1185+
add("ClampStochasticRoundBF16ToE5M2INTEL",
1186+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1187+
internal::OpClampStochasticRoundFToFINTEL});
1188+
add("ClampStochasticRoundBF16ToE4M3INTEL",
1189+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1190+
internal::OpClampStochasticRoundFToFINTEL});
11091191
}
11101192

11111193
// clang-format on

lib/SPIRV/SPIRVReader.cpp

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,11 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
314314

315315
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
316316
switch (T->getFloatBitWidth()) {
317+
case 4:
318+
// No LLVM IR counterpart for FP4 - map it to i4.
319+
return Type::getIntNTy(*Context, 4);
317320
case 8:
318-
// No LLVM IR counter part for FP8 - map it on i8
321+
// No LLVM IR counterpart for FP8 - map it to i8.
319322
return Type::getIntNTy(*Context, 8);
320323
case 16:
321324
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
@@ -1081,11 +1084,12 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10811084
return FPEncodingWrap::IEEE754;
10821085
};
10831086

1084-
auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1085-
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
1087+
auto IsFP4OrFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1088+
return Encoding == FPEncodingWrap::E4M3 ||
1089+
Encoding == FPEncodingWrap::E5M2 || Encoding == FPEncodingWrap::E2M1;
10861090
};
10871091

1088-
switch (BC->getOpCode()) {
1092+
switch (static_cast<unsigned>(BC->getOpCode())) {
10891093
case OpPtrCastToGeneric:
10901094
case OpGenericCastToPtr:
10911095
case OpPtrCastToCrossWorkgroupINTEL:
@@ -1106,6 +1110,11 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11061110
case OpUConvert:
11071111
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
11081112
break;
1113+
case internal::OpClampConvertFToFINTEL:
1114+
case internal::OpClampConvertFToSINTEL:
1115+
case internal::OpStochasticRoundFToFINTEL:
1116+
case internal::OpClampStochasticRoundFToFINTEL:
1117+
case internal::OpClampStochasticRoundFToSINTEL:
11091118
case OpConvertSToF:
11101119
case OpConvertFToS:
11111120
case OpConvertUToF:
@@ -1130,7 +1139,7 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11301139

11311140
FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
11321141
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
1133-
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
1142+
if (IsFP4OrFP8Encoding(SrcEnc) || IsFP4OrFP8Encoding(DstEnc) ||
11341143
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
11351144
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
11361145
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
@@ -1140,13 +1149,47 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11401149
std::string BuiltinName =
11411150
kSPIRVName::InternalBuiltinPrefix + std::string(Conv);
11421151
BuiltinFuncMangleInfo Info;
1143-
std::string MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1152+
std::string MangledName;
1153+
// Translate additional Ops for stochastic conversions.
1154+
if (OC == internal::OpStochasticRoundFToFINTEL ||
1155+
OC == internal::OpClampStochasticRoundFToFINTEL ||
1156+
OC == internal::OpClampStochasticRoundFToSINTEL) {
1157+
// Seed.
1158+
Ops.emplace_back(transValue(SPVOps[1], F, BB, true));
1159+
OpsTys.emplace_back(Ops[1]->getType());
1160+
constexpr unsigned MaxOpsSize = 3;
1161+
if (SPVOps.size() == MaxOpsSize) {
1162+
// New Seed.
1163+
Ops.emplace_back(transValue(SPVOps[2], F, BB, true));
1164+
1165+
// The following mess is needed to create a function with correct
1166+
// mangling.
1167+
SPIRVType *PtrTy = SPVOps[2]->getType();
1168+
const unsigned AS =
1169+
SPIRSPIRVAddrSpaceMap::rmap(PtrTy->getPointerStorageClass());
1170+
Type *ElementTy = transType(PtrTy->getPointerElementType());
1171+
OpsTys.emplace_back(TypedPointerType::get(ElementTy, AS));
1172+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1173+
// But to create function itself we need untyped pointer type.
1174+
OpsTys[2] = opaquifyType(OpsTys[2]);
1175+
}
1176+
}
1177+
1178+
if (MangledName.empty())
1179+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
11441180

11451181
FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
11461182
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
11471183
return CallInst::Create(Func, Ops, "", BB);
11481184
}
11491185
}
1186+
// These conversions can be done without __builtin_spirv prefixed functions
1187+
// as their operand and result types have native representation in LLVM IR.
1188+
if (OC == internal::OpClampConvertFToFINTEL ||
1189+
OC == internal::OpStochasticRoundFToFINTEL ||
1190+
OC == internal::OpClampStochasticRoundFToFINTEL)
1191+
return mapValue(BV, transSPIRVBuiltinFromInst(
1192+
static_cast<SPIRVInstruction *>(BV), BB));
11501193

11511194
if (OC == OpFConvert) {
11521195
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
@@ -3073,7 +3116,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
30733116
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
30743117
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
30753118
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
3076-
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
3119+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
3120+
OutMatrixElementTy->isTypeFloat(
3121+
4, internal::FPEncodingFloat4E2M1INTEL) ||
3122+
InMatrixElementTy->isTypeFloat(4,
3123+
internal::FPEncodingFloat4E2M1INTEL))
30773124
Inst = transConvertInst(BV, F, BB);
30783125
else
30793126
Inst = transSPIRVBuiltinFromInst(BI, BB);
@@ -3082,6 +3129,8 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
30823129
}
30833130
return mapValue(BV, Inst);
30843131
}
3132+
if (isIntelCvtOpCode(OC))
3133+
return mapValue(BV, transConvertInst(BV, F, BB));
30853134
return mapValue(
30863135
BV, transSPIRVBuiltinFromInst(static_cast<SPIRVInstruction *>(BV), BB));
30873136
}
@@ -3898,6 +3947,11 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
38983947
case internal::OpTaskSequenceCreateINTEL:
38993948
case internal::OpConvertHandleToImageINTEL:
39003949
case internal::OpConvertHandleToSampledImageINTEL:
3950+
case internal::OpClampConvertFToFINTEL:
3951+
case internal::OpClampConvertFToSINTEL:
3952+
case internal::OpStochasticRoundFToFINTEL:
3953+
case internal::OpClampStochasticRoundFToFINTEL:
3954+
case internal::OpClampStochasticRoundFToSINTEL:
39013955
AddRetTypePostfix = true;
39023956
break;
39033957
default: {

lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ void SPIRVToOCLBase::visitCastInst(CastInst &Cast) {
247247
DstVecTy->getScalarSizeInBits() == 1)
248248
return;
249249

250+
// We don't have OpenCL builtins for 4-bit conversions.
251+
if (DstVecTy->getScalarSizeInBits() == 4 || SrcTy->getScalarSizeInBits() == 4)
252+
return;
253+
250254
// Assemble built-in name -> convert_gentypeN
251255
std::string CastBuiltInName(kOCLBuiltinName::ConvertPrefix);
252256
// Check if this is 'floating point -> unsigned integer' cast

0 commit comments

Comments
 (0)