@@ -314,6 +314,9 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
314314
315315Type *SPIRVToLLVM::transFPType (SPIRVType *T) {
316316 switch (T->getFloatBitWidth ()) {
317+ case 8 :
318+ // No LLVM IR counter part for FP8 - map it on i8
319+ return Type::getIntNTy (*Context, 8 );
317320 case 16 :
318321 if (T->isTypeFloat (16 , FPEncodingBFloat16KHR))
319322 return Type::getBFloatTy (*Context);
@@ -1066,6 +1069,22 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10661069 CastInst::CastOps CO = Instruction::BitCast;
10671070 bool IsExt =
10681071 Dst->getScalarSizeInBits () > Src->getType ()->getScalarSizeInBits ();
1072+
1073+ auto GetFPEncoding = [](SPIRVType *Ty) -> FPEncodingWrap {
1074+ if (Ty->isTypeFloat ()) {
1075+ unsigned Enc =
1076+ static_cast <SPIRVTypeFloat *>(Ty)->getFloatingPointEncoding ();
1077+ return static_cast <FPEncodingWrap>(Enc);
1078+ }
1079+ if (Ty->isTypeInt ())
1080+ return FPEncodingWrap::Integer;
1081+ return FPEncodingWrap::IEEE754;
1082+ };
1083+
1084+ auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1085+ return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
1086+ };
1087+
10691088 switch (BC->getOpCode ()) {
10701089 case OpPtrCastToGeneric:
10711090 case OpGenericCastToPtr:
@@ -1087,10 +1106,61 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10871106 case OpUConvert:
10881107 CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
10891108 break ;
1090- case OpFConvert:
1091- CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
1109+ case OpConvertSToF:
1110+ case OpConvertFToS:
1111+ case OpConvertUToF:
1112+ case OpConvertFToU:
1113+ case OpFConvert: {
1114+ const auto OC = BC->getOpCode ();
1115+ {
1116+ auto SPVOps = BC->getOperands ();
1117+ auto *SPVSrcTy = SPVOps[0 ]->getType ();
1118+ auto *SPVDstTy = BC->getType ();
1119+ if (SPVSrcTy->isTypeVector ()) {
1120+ SPVSrcTy = SPVSrcTy->getVectorComponentType ();
1121+ } else if (SPVSrcTy->isTypeCooperativeMatrixKHR ()) {
1122+ auto *MT = static_cast <SPIRVTypeCooperativeMatrixKHR *>(SPVSrcTy);
1123+ SPVSrcTy = MT->getCompType ();
1124+ }
1125+ if (SPVDstTy->isTypeVector ()) {
1126+ SPVDstTy = SPVDstTy->getVectorComponentType ();
1127+ } else if (SPVDstTy->isTypeCooperativeMatrixKHR ()) {
1128+ auto *MT = static_cast <SPIRVTypeCooperativeMatrixKHR *>(SPVDstTy);
1129+ SPVDstTy = MT->getCompType ();
1130+ }
1131+ FPEncodingWrap SrcEnc = GetFPEncoding (SPVSrcTy);
1132+ FPEncodingWrap DstEnc = GetFPEncoding (SPVDstTy);
1133+ if (IsFP8Encoding (SrcEnc) || IsFP8Encoding (DstEnc) ||
1134+ SPVSrcTy->isTypeInt (4 ) || SPVDstTy->isTypeInt (4 )) {
1135+ FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode ()};
1136+ auto Conv = SPIRV::FPConvertToEncodingMap::rmap (FPDesc);
1137+ std::vector<Value *> Ops = {Src};
1138+ std::vector<Type *> OpsTys = {Src->getType ()};
1139+
1140+ std::string BuiltinName =
1141+ kSPIRVName ::InternalPrefix + std::string (Conv);
1142+ BuiltinFuncMangleInfo Info;
1143+ std::string MangledName;
1144+
1145+ if (MangledName.empty ())
1146+ MangledName = mangleBuiltin (BuiltinName, OpsTys, &Info);
1147+
1148+ FunctionType *FTy = FunctionType::get (Dst, OpsTys, false );
1149+ FunctionCallee Func = M->getOrInsertFunction (MangledName, FTy);
1150+ return CallInst::Create (Func, Ops, " " , BB);
1151+ }
1152+ }
1153+
1154+ if (OC == OpFConvert) {
1155+ CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
1156+ break ;
1157+ }
1158+ CO = static_cast <CastInst::CastOps>(OpCodeMap::rmap (OC));
10921159 break ;
1160+ }
10931161 case OpBitcast:
1162+ if (!Dst->isPointerTy () && Dst == Src->getType ())
1163+ return Src;
10941164 // OpBitcast need to be handled as a special-case when the source is a
10951165 // pointer and the destination is not a pointer, and where the source is not
10961166 // a pointer and the destination is a pointer. This is supported by the
@@ -2990,11 +3060,29 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
29903060 if (isCvtOpCode (OC) && OC != OpGenericCastToPtrExplicit) {
29913061 auto *BI = static_cast <SPIRVInstruction *>(BV);
29923062 Value *Inst = nullptr ;
2993- if (BI->hasFPRoundingMode () || BI->isSaturatedConversion () ||
2994- BI->getType ()->isTypeCooperativeMatrixKHR ())
3063+ if (BI->hasFPRoundingMode () || BI->isSaturatedConversion ()) {
29953064 Inst = transSPIRVBuiltinFromInst (BI, BB);
2996- else
3065+ } else if (BI->getType ()->isTypeCooperativeMatrixKHR ()) {
3066+ // For cooperative matrix conversions generate __builtin_spirv
3067+ // conversions instead of __spirv_FConvert in case of mini-float
3068+ // type element type.
3069+ auto *OutMatrixElementTy =
3070+ static_cast <SPIRVTypeCooperativeMatrixKHR *>(BI->getType ())
3071+ ->getCompType ();
3072+ auto *InMatrixElementTy =
3073+ static_cast <SPIRVTypeCooperativeMatrixKHR *>(
3074+ static_cast <SPIRVUnary *>(BI)->getOperand (0 )->getType ())
3075+ ->getCompType ();
3076+ if (OutMatrixElementTy->isTypeFloat (8 , FPEncodingFloat8E4M3EXT) ||
3077+ OutMatrixElementTy->isTypeFloat (8 , FPEncodingFloat8E5M2EXT) ||
3078+ InMatrixElementTy->isTypeFloat (8 , FPEncodingFloat8E4M3EXT) ||
3079+ InMatrixElementTy->isTypeFloat (8 , FPEncodingFloat8E5M2EXT))
3080+ Inst = transConvertInst (BV, F, BB);
3081+ else
3082+ Inst = transSPIRVBuiltinFromInst (BI, BB);
3083+ } else {
29973084 Inst = transConvertInst (BV, F, BB);
3085+ }
29983086 return mapValue (BV, Inst);
29993087 }
30003088 return mapValue (
0 commit comments