Skip to content

Commit 6b00efc

Browse files
author
Dmitry Sidorov
committed
[Backport to 14] Implement SPV_INTEL_float4 and SPV_INTEL_fp_conversions extensions (KhronosGroup#3419)
As well as their appropriate conversions via __builtin_spirv mechanism. Specification: intel/llvm#20467 Signed-off-by: Dmitry Sidorov <dmitry.sidorov@intel.com>
1 parent 96171e3 commit 6b00efc

File tree

15 files changed

+1102
-105
lines changed

15 files changed

+1102
-105
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,5 @@ EXT(SPV_INTEL_predicated_io)
8080
EXT(SPV_INTEL_sigmoid)
8181
EXT(SPV_INTEL_ternary_bitwise_function)
8282
EXT(SPV_INTEL_int4)
83+
EXT(SPV_INTEL_float4)
84+
EXT(SPV_INTEL_fp_conversions)

lib/SPIRV/SPIRVInternal.h

Lines changed: 112 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,7 @@ enum FPEncodingWrap {
11691169
BF16 = FPEncoding::FPEncodingBFloat16KHR,
11701170
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
11711171
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
1172+
E2M1 = internal::FPEncodingFloat4E2M1INTEL,
11721173
};
11731174

11741175
// Structure describing non-trivial conversions (FP8 and int4)
@@ -1198,36 +1199,117 @@ typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;
11981199

11991200
// clang-format off
12001201
template <> inline void FPConvertToEncodingMap::init() {
1201-
// 8-bit conversions
1202-
add("ConvertE4M3ToFP16EXT",
1203-
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1204-
add("ConvertE5M2ToFP16EXT",
1205-
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1206-
add("ConvertE4M3ToBF16EXT",
1207-
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1208-
add("ConvertE5M2ToBF16EXT",
1209-
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1210-
add("ConvertFP16ToE4M3EXT",
1211-
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1212-
add("ConvertFP16ToE5M2EXT",
1213-
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1214-
add("ConvertBF16ToE4M3EXT",
1215-
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1216-
add("ConvertBF16ToE5M2EXT",
1217-
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1218-
1219-
add("ConvertInt4ToE4M3INTEL",
1220-
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1221-
add("ConvertInt4ToE5M2INTEL",
1222-
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1223-
add("ConvertInt4ToFP16INTEL",
1224-
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1225-
add("ConvertInt4ToBF16INTEL",
1226-
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1227-
add("ConvertFP16ToInt4INTEL",
1228-
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1229-
add("ConvertBF16ToInt4INTEL",
1230-
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1202+
// 4-bit conversions
1203+
add("ConvertE2M1ToE4M3INTEL",
1204+
{FPEncodingWrap::E2M1, FPEncodingWrap::E4M3, OpFConvert});
1205+
add("ConvertE2M1ToE5M2INTEL",
1206+
{FPEncodingWrap::E2M1, FPEncodingWrap::E5M2, OpFConvert});
1207+
add("ConvertE2M1ToFP16INTEL",
1208+
{FPEncodingWrap::E2M1, FPEncodingWrap::IEEE754, OpFConvert});
1209+
add("ConvertE2M1ToBF16INTEL",
1210+
{FPEncodingWrap::E2M1, FPEncodingWrap::BF16, OpFConvert});
1211+
1212+
add("ConvertInt4ToE4M3INTEL",
1213+
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1214+
add("ConvertInt4ToE5M2INTEL",
1215+
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1216+
add("ConvertInt4ToFP16INTEL",
1217+
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1218+
add("ConvertInt4ToBF16INTEL",
1219+
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1220+
add("ConvertInt4ToInt8INTEL",
1221+
{FPEncodingWrap::Integer, FPEncodingWrap::Integer, OpSConvert});
1222+
1223+
add("ConvertFP16ToE2M1INTEL",
1224+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1, OpFConvert});
1225+
add("ConvertBF16ToE2M1INTEL",
1226+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1, OpFConvert});
1227+
add("ConvertFP16ToInt4INTEL",
1228+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1229+
add("ConvertBF16ToInt4INTEL",
1230+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1231+
1232+
// 8-bit conversions
1233+
add("ConvertE4M3ToFP16EXT",
1234+
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1235+
add("ConvertE5M2ToFP16EXT",
1236+
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1237+
add("ConvertE4M3ToBF16EXT",
1238+
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1239+
add("ConvertE5M2ToBF16EXT",
1240+
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1241+
add("ConvertFP16ToE4M3EXT",
1242+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1243+
add("ConvertFP16ToE5M2EXT",
1244+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1245+
add("ConvertBF16ToE4M3EXT",
1246+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1247+
add("ConvertBF16ToE5M2EXT",
1248+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1249+
1250+
// SPV_INTEL_fp_conversions
1251+
add("ClampConvertFP16ToE2M1INTEL",
1252+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
1253+
internal::OpClampConvertFToFINTEL});
1254+
add("ClampConvertBF16ToE2M1INTEL",
1255+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
1256+
internal::OpClampConvertFToFINTEL});
1257+
add("ClampConvertFP16ToE4M3INTEL",
1258+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1259+
internal::OpClampConvertFToFINTEL});
1260+
add("ClampConvertBF16ToE4M3INTEL",
1261+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1262+
internal::OpClampConvertFToFINTEL});
1263+
add("ClampConvertFP16ToE5M2INTEL",
1264+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1265+
internal::OpClampConvertFToFINTEL});
1266+
add("ClampConvertBF16ToE5M2INTEL",
1267+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1268+
internal::OpClampConvertFToFINTEL});
1269+
add("ClampConvertFP16ToInt4INTEL",
1270+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
1271+
internal::OpClampConvertFToSINTEL});
1272+
add("ClampConvertBF16ToInt4INTEL",
1273+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
1274+
internal::OpClampConvertFToSINTEL});
1275+
1276+
add("StochasticRoundFP16ToE5M2INTEL",
1277+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1278+
internal::OpStochasticRoundFToFINTEL});
1279+
add("StochasticRoundFP16ToE4M3INTEL",
1280+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1281+
internal::OpStochasticRoundFToFINTEL});
1282+
add("StochasticRoundBF16ToE5M2INTEL",
1283+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1284+
internal::OpStochasticRoundFToFINTEL});
1285+
add("StochasticRoundBF16ToE4M3INTEL",
1286+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1287+
internal::OpStochasticRoundFToFINTEL});
1288+
add("StochasticRoundFP16ToE2M1INTEL",
1289+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
1290+
internal::OpStochasticRoundFToFINTEL});
1291+
add("StochasticRoundBF16ToE2M1INTEL",
1292+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
1293+
internal::OpStochasticRoundFToFINTEL});
1294+
add("ClampStochasticRoundFP16ToInt4INTEL",
1295+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
1296+
internal::OpClampStochasticRoundFToSINTEL});
1297+
add("ClampStochasticRoundBF16ToInt4INTEL",
1298+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
1299+
internal::OpClampStochasticRoundFToSINTEL});
1300+
1301+
add("ClampStochasticRoundFP16ToE5M2INTEL",
1302+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1303+
internal::OpClampStochasticRoundFToFINTEL});
1304+
add("ClampStochasticRoundFP16ToE4M3INTEL",
1305+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1306+
internal::OpClampStochasticRoundFToFINTEL});
1307+
add("ClampStochasticRoundBF16ToE5M2INTEL",
1308+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1309+
internal::OpClampStochasticRoundFToFINTEL});
1310+
add("ClampStochasticRoundBF16ToE4M3INTEL",
1311+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1312+
internal::OpClampStochasticRoundFToFINTEL});
12311313
}
12321314

12331315
// clang-format on

lib/SPIRV/SPIRVReader.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,11 @@ llvm::Optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
309309

310310
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
311311
switch (T->getFloatBitWidth()) {
312+
case 4:
313+
// No LLVM IR counter part for FP4 - map it on i4.
314+
return Type::getIntNTy(*Context, 4);
312315
case 8:
313-
// No LLVM IR counter part for FP8 - map it on i8
316+
// No LLVM IR counter part for FP8 - map it on i8.
314317
return Type::getIntNTy(*Context, 8);
315318
case 16:
316319
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
@@ -1066,11 +1069,12 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10661069
return FPEncodingWrap::IEEE754;
10671070
};
10681071

1069-
auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1070-
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
1072+
auto IsFP4OrFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1073+
return Encoding == FPEncodingWrap::E4M3 ||
1074+
Encoding == FPEncodingWrap::E5M2 || Encoding == FPEncodingWrap::E2M1;
10711075
};
10721076

1073-
switch (BC->getOpCode()) {
1077+
switch (static_cast<unsigned>(BC->getOpCode())) {
10741078
case OpPtrCastToGeneric:
10751079
case OpGenericCastToPtr:
10761080
case OpPtrCastToCrossWorkgroupINTEL:
@@ -1091,6 +1095,11 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10911095
case OpUConvert:
10921096
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
10931097
break;
1098+
case internal::OpClampConvertFToFINTEL:
1099+
case internal::OpClampConvertFToSINTEL:
1100+
case internal::OpStochasticRoundFToFINTEL:
1101+
case internal::OpClampStochasticRoundFToFINTEL:
1102+
case internal::OpClampStochasticRoundFToSINTEL:
10941103
case OpConvertSToF:
10951104
case OpConvertFToS:
10961105
case OpConvertUToF:
@@ -1115,7 +1124,7 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11151124

11161125
FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
11171126
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
1118-
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
1127+
if (IsFP4OrFP8Encoding(SrcEnc) || IsFP4OrFP8Encoding(DstEnc) ||
11191128
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
11201129
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
11211130
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
@@ -1158,6 +1167,13 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11581167
return CallInst::Create(Func, Ops, "", BB);
11591168
}
11601169
}
1170+
// These conversions can be done without __builtin_spirv prefixed functions
1171+
// as their operand and result types have native representation in LLVM IR.
1172+
if (OC == internal::OpClampConvertFToFINTEL ||
1173+
OC == internal::OpStochasticRoundFToFINTEL ||
1174+
OC == internal::OpClampStochasticRoundFToFINTEL)
1175+
return mapValue(BV, transSPIRVBuiltinFromInst(
1176+
static_cast<SPIRVInstruction *>(BV), BB));
11611177

11621178
if (OC == OpFConvert) {
11631179
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
@@ -2968,7 +2984,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
29682984
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
29692985
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
29702986
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
2971-
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
2987+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
2988+
OutMatrixElementTy->isTypeFloat(
2989+
4, internal::FPEncodingFloat4E2M1INTEL) ||
2990+
InMatrixElementTy->isTypeFloat(4,
2991+
internal::FPEncodingFloat4E2M1INTEL))
29722992
Inst = transConvertInst(BV, F, BB);
29732993
else
29742994
Inst = transSPIRVBuiltinFromInst(BI, BB);
@@ -2977,6 +2997,8 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
29772997
}
29782998
return mapValue(BV, Inst);
29792999
}
3000+
if (isIntelCvtOpCode(OC))
3001+
return mapValue(BV, transConvertInst(BV, F, BB));
29803002
return mapValue(
29813003
BV, transSPIRVBuiltinFromInst(static_cast<SPIRVInstruction *>(BV), BB));
29823004
}
@@ -3696,6 +3718,11 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
36963718
case internal::OpCooperativeMatrixLoadCheckedINTEL:
36973719
case internal::OpConvertHandleToImageINTEL:
36983720
case internal::OpConvertHandleToSampledImageINTEL:
3721+
case internal::OpClampConvertFToFINTEL:
3722+
case internal::OpClampConvertFToSINTEL:
3723+
case internal::OpStochasticRoundFToFINTEL:
3724+
case internal::OpClampStochasticRoundFToFINTEL:
3725+
case internal::OpClampStochasticRoundFToSINTEL:
36993726
AddRetTypePostfix = true;
37003727
break;
37013728
default: {

lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ void SPIRVToOCLBase::visitCastInst(CastInst &Cast) {
233233
DstVecTy->getScalarSizeInBits() == 1)
234234
return;
235235

236+
// We don't have OpenCL builtins for 4-bit conversions.
237+
if (DstVecTy->getScalarSizeInBits() == 4 || SrcTy->getScalarSizeInBits() == 4)
238+
return;
239+
236240
// Assemble built-in name -> convert_gentypeN
237241
std::string CastBuiltInName(kOCLBuiltinName::ConvertPrefix);
238242
// Check if this is 'floating point -> unsigned integer' cast

0 commit comments

Comments
 (0)