Skip to content

Commit f6172dd

Browse files
author
Dmitry Sidorov
committed
[Backport to 16] Implement SPV_INTEL_float4 and SPV_INTEL_fp_conversions extensions (KhronosGroup#3419)
This change also implements the corresponding conversions via the __builtin_spirv mechanism. Specification: intel/llvm#20467 Signed-off-by: Dmitry Sidorov <dmitry.sidorov@intel.com>
1 parent 9704753 commit f6172dd

File tree

15 files changed

+1125
-105
lines changed

15 files changed

+1125
-105
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,5 @@ EXT(SPV_INTEL_shader_atomic_bfloat16)
8282
EXT(SPV_EXT_float8)
8383
EXT(SPV_INTEL_predicated_io)
8484
EXT(SPV_INTEL_sigmoid)
85+
EXT(SPV_INTEL_float4)
86+
EXT(SPV_INTEL_fp_conversions)

lib/SPIRV/SPIRVInternal.h

Lines changed: 112 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,7 @@ enum FPEncodingWrap {
11451145
BF16 = FPEncoding::FPEncodingBFloat16KHR,
11461146
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
11471147
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
1148+
E2M1 = internal::FPEncodingFloat4E2M1INTEL,
11481149
};
11491150

11501151
// Structure describing non-trivial conversions (FP8 and int4)
@@ -1173,36 +1174,117 @@ typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;
11731174

11741175
// clang-format off
11751176
template <> inline void FPConvertToEncodingMap::init() {
1176-
// 8-bit conversions
1177-
add("ConvertE4M3ToFP16EXT",
1178-
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1179-
add("ConvertE5M2ToFP16EXT",
1180-
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1181-
add("ConvertE4M3ToBF16EXT",
1182-
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1183-
add("ConvertE5M2ToBF16EXT",
1184-
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1185-
add("ConvertFP16ToE4M3EXT",
1186-
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1187-
add("ConvertFP16ToE5M2EXT",
1188-
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1189-
add("ConvertBF16ToE4M3EXT",
1190-
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1191-
add("ConvertBF16ToE5M2EXT",
1192-
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1193-
1194-
add("ConvertInt4ToE4M3INTEL",
1195-
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1196-
add("ConvertInt4ToE5M2INTEL",
1197-
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1198-
add("ConvertInt4ToFP16INTEL",
1199-
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1200-
add("ConvertInt4ToBF16INTEL",
1201-
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1202-
add("ConvertFP16ToInt4INTEL",
1203-
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1204-
add("ConvertBF16ToInt4INTEL",
1205-
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1177+
// 4-bit conversions
1178+
add("ConvertE2M1ToE4M3INTEL",
1179+
{FPEncodingWrap::E2M1, FPEncodingWrap::E4M3, OpFConvert});
1180+
add("ConvertE2M1ToE5M2INTEL",
1181+
{FPEncodingWrap::E2M1, FPEncodingWrap::E5M2, OpFConvert});
1182+
add("ConvertE2M1ToFP16INTEL",
1183+
{FPEncodingWrap::E2M1, FPEncodingWrap::IEEE754, OpFConvert});
1184+
add("ConvertE2M1ToBF16INTEL",
1185+
{FPEncodingWrap::E2M1, FPEncodingWrap::BF16, OpFConvert});
1186+
1187+
add("ConvertInt4ToE4M3INTEL",
1188+
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1189+
add("ConvertInt4ToE5M2INTEL",
1190+
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1191+
add("ConvertInt4ToFP16INTEL",
1192+
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1193+
add("ConvertInt4ToBF16INTEL",
1194+
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1195+
add("ConvertInt4ToInt8INTEL",
1196+
{FPEncodingWrap::Integer, FPEncodingWrap::Integer, OpSConvert});
1197+
1198+
add("ConvertFP16ToE2M1INTEL",
1199+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1, OpFConvert});
1200+
add("ConvertBF16ToE2M1INTEL",
1201+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1, OpFConvert});
1202+
add("ConvertFP16ToInt4INTEL",
1203+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1204+
add("ConvertBF16ToInt4INTEL",
1205+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1206+
1207+
// 8-bit conversions
1208+
add("ConvertE4M3ToFP16EXT",
1209+
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1210+
add("ConvertE5M2ToFP16EXT",
1211+
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1212+
add("ConvertE4M3ToBF16EXT",
1213+
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1214+
add("ConvertE5M2ToBF16EXT",
1215+
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1216+
add("ConvertFP16ToE4M3EXT",
1217+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1218+
add("ConvertFP16ToE5M2EXT",
1219+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1220+
add("ConvertBF16ToE4M3EXT",
1221+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1222+
add("ConvertBF16ToE5M2EXT",
1223+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1224+
1225+
// SPV_INTEL_fp_conversions
1226+
add("ClampConvertFP16ToE2M1INTEL",
1227+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
1228+
internal::OpClampConvertFToFINTEL});
1229+
add("ClampConvertBF16ToE2M1INTEL",
1230+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
1231+
internal::OpClampConvertFToFINTEL});
1232+
add("ClampConvertFP16ToE4M3INTEL",
1233+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1234+
internal::OpClampConvertFToFINTEL});
1235+
add("ClampConvertBF16ToE4M3INTEL",
1236+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1237+
internal::OpClampConvertFToFINTEL});
1238+
add("ClampConvertFP16ToE5M2INTEL",
1239+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1240+
internal::OpClampConvertFToFINTEL});
1241+
add("ClampConvertBF16ToE5M2INTEL",
1242+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1243+
internal::OpClampConvertFToFINTEL});
1244+
add("ClampConvertFP16ToInt4INTEL",
1245+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
1246+
internal::OpClampConvertFToSINTEL});
1247+
add("ClampConvertBF16ToInt4INTEL",
1248+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
1249+
internal::OpClampConvertFToSINTEL});
1250+
1251+
add("StochasticRoundFP16ToE5M2INTEL",
1252+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1253+
internal::OpStochasticRoundFToFINTEL});
1254+
add("StochasticRoundFP16ToE4M3INTEL",
1255+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1256+
internal::OpStochasticRoundFToFINTEL});
1257+
add("StochasticRoundBF16ToE5M2INTEL",
1258+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1259+
internal::OpStochasticRoundFToFINTEL});
1260+
add("StochasticRoundBF16ToE4M3INTEL",
1261+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1262+
internal::OpStochasticRoundFToFINTEL});
1263+
add("StochasticRoundFP16ToE2M1INTEL",
1264+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
1265+
internal::OpStochasticRoundFToFINTEL});
1266+
add("StochasticRoundBF16ToE2M1INTEL",
1267+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
1268+
internal::OpStochasticRoundFToFINTEL});
1269+
add("ClampStochasticRoundFP16ToInt4INTEL",
1270+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
1271+
internal::OpClampStochasticRoundFToSINTEL});
1272+
add("ClampStochasticRoundBF16ToInt4INTEL",
1273+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
1274+
internal::OpClampStochasticRoundFToSINTEL});
1275+
1276+
add("ClampStochasticRoundFP16ToE5M2INTEL",
1277+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1278+
internal::OpClampStochasticRoundFToFINTEL});
1279+
add("ClampStochasticRoundFP16ToE4M3INTEL",
1280+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1281+
internal::OpClampStochasticRoundFToFINTEL});
1282+
add("ClampStochasticRoundBF16ToE5M2INTEL",
1283+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1284+
internal::OpClampStochasticRoundFToFINTEL});
1285+
add("ClampStochasticRoundBF16ToE4M3INTEL",
1286+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1287+
internal::OpClampStochasticRoundFToFINTEL});
12061288
}
12071289

12081290
// clang-format on

lib/SPIRV/SPIRVReader.cpp

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,11 @@ llvm::Optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
310310

311311
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
312312
switch (T->getFloatBitWidth()) {
313+
case 4:
314+
// No LLVM IR counterpart for FP4 - map it to i4.
315+
return Type::getIntNTy(*Context, 4);
313316
case 8:
314-
// No LLVM IR counter part for FP8 - map it on i8
317+
// No LLVM IR counterpart for FP8 - map it to i8.
315318
return Type::getIntNTy(*Context, 8);
316319
case 16:
317320
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
@@ -1171,11 +1174,12 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11711174
return FPEncodingWrap::IEEE754;
11721175
};
11731176

1174-
auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1175-
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
1177+
auto IsFP4OrFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1178+
return Encoding == FPEncodingWrap::E4M3 ||
1179+
Encoding == FPEncodingWrap::E5M2 || Encoding == FPEncodingWrap::E2M1;
11761180
};
11771181

1178-
switch (BC->getOpCode()) {
1182+
switch (static_cast<unsigned>(BC->getOpCode())) {
11791183
case OpPtrCastToGeneric:
11801184
case OpGenericCastToPtr:
11811185
case OpPtrCastToCrossWorkgroupINTEL:
@@ -1196,6 +1200,11 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11961200
case OpUConvert:
11971201
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
11981202
break;
1203+
case internal::OpClampConvertFToFINTEL:
1204+
case internal::OpClampConvertFToSINTEL:
1205+
case internal::OpStochasticRoundFToFINTEL:
1206+
case internal::OpClampStochasticRoundFToFINTEL:
1207+
case internal::OpClampStochasticRoundFToSINTEL:
11991208
case OpConvertSToF:
12001209
case OpConvertFToS:
12011210
case OpConvertUToF:
@@ -1220,7 +1229,7 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
12201229

12211230
FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
12221231
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
1223-
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
1232+
if (IsFP4OrFP8Encoding(SrcEnc) || IsFP4OrFP8Encoding(DstEnc) ||
12241233
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
12251234
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
12261235
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
@@ -1230,13 +1239,47 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
12301239
std::string BuiltinName =
12311240
kSPIRVName::InternalBuiltinPrefix + std::string(Conv);
12321241
BuiltinFuncMangleInfo Info;
1233-
std::string MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1242+
std::string MangledName;
1243+
// Translate additional Ops for stochastic conversions.
1244+
if (OC == internal::OpStochasticRoundFToFINTEL ||
1245+
OC == internal::OpClampStochasticRoundFToFINTEL ||
1246+
OC == internal::OpClampStochasticRoundFToSINTEL) {
1247+
// Seed.
1248+
Ops.emplace_back(transValue(SPVOps[1], F, BB, true));
1249+
OpsTys.emplace_back(Ops[1]->getType());
1250+
constexpr unsigned MaxOpsSize = 3;
1251+
if (SPVOps.size() == MaxOpsSize) {
1252+
// New Seed.
1253+
Ops.emplace_back(transValue(SPVOps[2], F, BB, true));
1254+
1255+
// The following mess is needed to create a function with correct
1256+
// mangling.
1257+
SPIRVType *PtrTy = SPVOps[2]->getType();
1258+
const unsigned AS =
1259+
SPIRSPIRVAddrSpaceMap::rmap(PtrTy->getPointerStorageClass());
1260+
Type *ElementTy = transType(PtrTy->getPointerElementType());
1261+
OpsTys.emplace_back(TypedPointerType::get(ElementTy, AS));
1262+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1263+
// But to create function itself we need untyped pointer type.
1264+
OpsTys[2] = opaquifyType(OpsTys[2]);
1265+
}
1266+
}
1267+
1268+
if (MangledName.empty())
1269+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
12341270

12351271
FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
12361272
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
12371273
return CallInst::Create(Func, Ops, "", BB);
12381274
}
12391275
}
1276+
// These conversions can be done without __builtin_spirv prefixed functions
1277+
// as their operand and result types have native representation in LLVM IR.
1278+
if (OC == internal::OpClampConvertFToFINTEL ||
1279+
OC == internal::OpStochasticRoundFToFINTEL ||
1280+
OC == internal::OpClampStochasticRoundFToFINTEL)
1281+
return mapValue(BV, transSPIRVBuiltinFromInst(
1282+
static_cast<SPIRVInstruction *>(BV), BB));
12401283

12411284
if (OC == OpFConvert) {
12421285
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
@@ -3062,7 +3105,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
30623105
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
30633106
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
30643107
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
3065-
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
3108+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
3109+
OutMatrixElementTy->isTypeFloat(
3110+
4, internal::FPEncodingFloat4E2M1INTEL) ||
3111+
InMatrixElementTy->isTypeFloat(4,
3112+
internal::FPEncodingFloat4E2M1INTEL))
30663113
Inst = transConvertInst(BV, F, BB);
30673114
else
30683115
Inst = transSPIRVBuiltinFromInst(BI, BB);
@@ -3071,6 +3118,8 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
30713118
}
30723119
return mapValue(BV, Inst);
30733120
}
3121+
if (isIntelCvtOpCode(OC))
3122+
return mapValue(BV, transConvertInst(BV, F, BB));
30743123
return mapValue(
30753124
BV, transSPIRVBuiltinFromInst(static_cast<SPIRVInstruction *>(BV), BB));
30763125
}
@@ -3792,6 +3841,11 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
37923841
case internal::OpCooperativeMatrixLoadCheckedINTEL:
37933842
case internal::OpConvertHandleToImageINTEL:
37943843
case internal::OpConvertHandleToSampledImageINTEL:
3844+
case internal::OpClampConvertFToFINTEL:
3845+
case internal::OpClampConvertFToSINTEL:
3846+
case internal::OpStochasticRoundFToFINTEL:
3847+
case internal::OpClampStochasticRoundFToFINTEL:
3848+
case internal::OpClampStochasticRoundFToSINTEL:
37953849
AddRetTypePostfix = true;
37963850
break;
37973851
default: {

lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,10 @@ void SPIRVToOCLBase::visitCastInst(CastInst &Cast) {
237237
DstVecTy->getScalarSizeInBits() == 1)
238238
return;
239239

240+
// We don't have OpenCL builtins for 4-bit conversions.
241+
if (DstVecTy->getScalarSizeInBits() == 4 || SrcTy->getScalarSizeInBits() == 4)
242+
return;
243+
240244
// Assemble built-in name -> convert_gentypeN
241245
std::string CastBuiltInName(kOCLBuiltinName::ConvertPrefix);
242246
// Check if this is 'floating point -> unsigned integer' cast

0 commit comments

Comments
 (0)