@@ -305,6 +305,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
305305 bool selectImageWriteIntrinsic (MachineInstr &I) const ;
306306 bool selectResourceGetPointer (Register &ResVReg, const SPIRVType *ResType,
307307 MachineInstr &I) const ;
308+ bool selectIsFpclass (Register ResVReg, const SPIRVType *ResType,
309+ MachineInstr &I) const ;
308310
309311 // Utilities
310312 std::pair<Register, bool >
@@ -893,6 +895,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
893895 case TargetOpcode::G_UBSANTRAP:
894896 case TargetOpcode::DBG_LABEL:
895897 return true ;
898+ case TargetOpcode::G_IS_FPCLASS:
899+ return selectIsFpclass (ResVReg,ResType,I);
896900
897901 default :
898902 return false ;
@@ -4021,6 +4025,316 @@ bool SPIRVInstructionSelector::loadHandleBeforePosition(
40214025 .constrainAllUses (TII, TRI, RBI);
40224026}
40234027
4028+ llvm::Type* getLLVMType (LLT srcLLTType, llvm::LLVMContext &context) {
4029+ if (srcLLTType.isScalar ()) {
4030+ switch (srcLLTType.getSizeInBits ()) {
4031+ case 32 : return llvm::Type::getFloatTy (context); // float
4032+ case 64 : return llvm::Type::getDoubleTy (context); // double
4033+ default :
4034+ llvm_unreachable (" Unsupported scalar floating-point type!" );
4035+ }
4036+ } else if (srcLLTType.isVector ()) {
4037+ unsigned numElements = srcLLTType.getNumElements ();
4038+ LLT elemType = srcLLTType.getElementType ();
4039+
4040+ if (elemType.isScalar ()) {
4041+ llvm::Type* baseType = nullptr ;
4042+ switch (elemType.getSizeInBits ()) {
4043+ case 32 : baseType = llvm::Type::getFloatTy (context); break ;
4044+ case 64 : baseType = llvm::Type::getDoubleTy (context); break ;
4045+ case 16 : baseType = llvm::Type::getHalfTy (context); break ;
4046+ default :
4047+ llvm_unreachable (" Unsupported vector element type!" );
4048+ }
4049+ return llvm::VectorType::get (baseType, llvm::ElementCount::getFixed (numElements));
4050+ }
4051+ }
4052+
4053+ llvm_unreachable (" Unsupported LLT type conversion!" );
4054+ }
4055+
4056+ int getBitWidth (LLT srcLLTType){
4057+ int bitWidth;
4058+ if (srcLLTType.isScalar ()){
4059+ bitWidth = srcLLTType.getSizeInBits ();
4060+ }else if (srcLLTType.isVector ()){
4061+ bitWidth = srcLLTType.getElementType ().getSizeInBits ();
4062+ }
4063+ return bitWidth;
4064+ }
4065+
4066+ bool SPIRVInstructionSelector::selectIsFpclass (Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
4067+ MachineIRBuilder MIRBuilder (I);
4068+ // spivType creation for Creating Intermediate Registers
4069+ std::vector<Register> ResultVec;
4070+ Register srcReg = I.getOperand (1 ).getReg ();
4071+ SPIRVType* srcSPIRVType = GR.getSPIRVTypeForVReg (srcReg);
4072+ LLT srcLLTType = GR.getRegType (srcSPIRVType);
4073+ Register ResReg = MRI->createGenericVirtualRegister (srcLLTType);
4074+ DstOp Res (ResReg);
4075+ Type* srcLLVMType = getLLVMType (srcLLTType, GR.CurMF ->getFunction ().getContext ());
4076+ int bitWidth = getBitWidth (srcLLTType);
4077+
4078+ Type *IntOpLLVMTy = IntegerType::getIntNTy (GR.CurMF ->getFunction ().getContext (), bitWidth);
4079+ if (srcLLVMType->isVectorTy ())
4080+ IntOpLLVMTy = FixedVectorType::get (IntOpLLVMTy, cast<FixedVectorType>(srcLLVMType)->getNumElements ());
4081+
4082+
4083+ SPIRVType *boolType = GR.getOrCreateSPIRVBoolType (I, TII);
4084+ SPIRVType *intType = GR.getOrCreateSPIRVIntegerType (bitWidth,I,TII);
4085+ SPIRVType *ScalarInt = GR.getOrCreateSPIRVIntegerType (bitWidth,I,TII);
4086+ unsigned N = GR.getScalarOrVectorComponentCount (ResType);
4087+ if (N > 1 ){
4088+ boolType = GR.getOrCreateSPIRVVectorType (boolType, N, I, TII);
4089+ intType = GR.getOrCreateSPIRVVectorType (intType, N, I, TII);
4090+ }
4091+
4092+ // function to create constant
4093+ auto createConstant = [&](int64_t val) -> Register {
4094+ Register constant = MRI->createVirtualRegister (GR.getRegClass (ScalarInt));
4095+ MIRBuilder.buildInstr (SPIRV::OpConstantI)
4096+ .addDef (constant)
4097+ .addUse (GR.getSPIRVTypeID (ScalarInt))
4098+ .addImm (val);
4099+ if (srcLLVMType->isVectorTy ()) {
4100+ Register compositeConstant = MRI->createVirtualRegister (GR.getRegClass (intType));
4101+ auto MIB = MIRBuilder.buildInstr (SPIRV::OpConstantComposite)
4102+ .addDef (compositeConstant)
4103+ .addUse (GR.getSPIRVTypeID (intType));
4104+ unsigned numElements = srcLLTType.getNumElements ();
4105+ for (unsigned i = 0 ; i < numElements; ++i) {
4106+ MIB.addUse (constant);
4107+ }
4108+ return compositeConstant;
4109+ }
4110+ return constant;
4111+ };
4112+ // Type* srcLLVMType;
4113+ // checking src type to create llvm Type Creation
4114+ // srcLLVMType = (bitWidth == 32) ? Type::getFloatTy(GR.CurMF->getFunction().getContext()) : Type::getDoubleTy(GR.CurMF->getFunction().getContext());
4115+ // llvm::FPClassTest FPClass = static_cast<llvm::FPClassTest>(I.getOperand(2).getImm());
4116+ llvm::FPClassTest FPClass = static_cast <llvm::FPClassTest>(I.getOperand (2 ).getImm ());
4117+ // edge cases
4118+ if (FPClass == 0 ){
4119+ // Register trueVector = createConstant(1);
4120+ // MIRBuilder.buildCopy(ResVReg, trueVector);
4121+ return true ;
4122+ }
4123+ if (FPClass == fcAllFlags){
4124+ // Register trueVector = createConstant(1);
4125+ // MIRBuilder.buildCopy(ResVReg, trueVector);
4126+ return true ;
4127+ }
4128+
4129+
4130+ // Instruction Template
4131+ auto instructionTemplate = [&](unsigned Opcode, SPIRVType* DestType, SPIRVType* ReturnType, auto &&... args) -> Register {
4132+ Register result = MRI->createVirtualRegister (GR.getRegClass (DestType));
4133+ auto &Instr = MIRBuilder.buildInstr (Opcode)
4134+ .addDef (result)
4135+ .addUse (GR.getSPIRVTypeID (DestType));
4136+
4137+ ([&](auto && arg) __attribute__ ((optimize (" O0" ))) {
4138+ if (std::is_integral_v<std::decay_t <decltype (arg)>>) {
4139+ Instr.addImm (arg);
4140+ }else {
4141+ Instr.addUse (arg);
4142+ }
4143+ }(args), ...);
4144+ return result;
4145+ };
4146+ // function to check if the sign bit is set or not
4147+ // 1 sign is set if 0 sign is not set
4148+ // inf && sign == 1 then -ve infinity
4149+ // inf && sign == 0 then not -ve infinity
4150+ // --------positive cases -------------//
4151+ // inf && ~sign - positive Infinity
4152+ // inf && ~sign - not positive Infinity
4153+ // sign bit test logic
4154+ Register SignBitTest = Register (0 );
4155+ Register NoSignTest = Register (0 );
4156+ auto GetNegPosInstTest = [&](Register TestInst,
4157+ bool IsNegative) -> Register {
4158+ Register result = MRI->createVirtualRegister (GR.getRegClass (boolType));
4159+ if (!SignBitTest.isValid ()){
4160+ SignBitTest = MRI->createVirtualRegister (GR.getRegClass (boolType));
4161+ MIRBuilder.buildInstr (SPIRV::OpSignBitSet)
4162+ .addDef (SignBitTest)
4163+ .addUse (GR.getSPIRVTypeID (boolType))
4164+ .addUse (srcReg);
4165+ }
4166+
4167+ if (IsNegative) {
4168+ MIRBuilder.buildInstr (SPIRV::OpLogicalAnd)
4169+ .addDef (result)
4170+ .addUse (GR.getSPIRVTypeID (boolType))
4171+ .addUse (SignBitTest)
4172+ .addUse (TestInst);
4173+ return result;
4174+ }
4175+ if (!NoSignTest.isValid ()){
4176+ NoSignTest = MRI->createVirtualRegister (GR.getRegClass (boolType));
4177+ MIRBuilder.buildInstr (SPIRV::OpLogicalNot)
4178+ .addDef (NoSignTest)
4179+ .addUse (GR.getSPIRVTypeID (boolType))
4180+ .addUse (SignBitTest);
4181+ }
4182+
4183+ MIRBuilder.buildInstr (SPIRV::OpLogicalAnd)
4184+ .addDef (result)
4185+ .addUse (GR.getSPIRVTypeID (boolType))
4186+ .addUse (NoSignTest)
4187+ .addUse (TestInst);
4188+ return result;
4189+ };
4190+
4191+ // if (srcLLVMType->isVectorTy())
4192+ // llvm::errs() << "is Vector Type" << "\n";
4193+ // IntOpLLVMTy = FixedVectorType::get(
4194+ // IntOpLLVMTy, cast<FixedVectorType>(srcLLVMType)->getNumElements());
4195+
4196+ const llvm::fltSemantics &Semantics =srcLLVMType->getScalarType ()->getFltSemantics ();
4197+ APInt Inf = APFloat::getInf (Semantics).bitcastToAPInt ();
4198+ APInt AllOneMantissa = APFloat::getLargest (Semantics).bitcastToAPInt () & ~Inf;
4199+
4200+ // Mask Inversion Logic
4201+ auto GetInvertedFPClassTest =
4202+ [](const llvm::FPClassTest Test) -> llvm::FPClassTest {
4203+ llvm::FPClassTest InvertedTest = ~Test & fcAllFlags;
4204+ switch (InvertedTest) {
4205+ case fcNan:
4206+ case fcSNan:
4207+ case fcQNan:
4208+ case fcInf:
4209+ case fcPosInf:
4210+ case fcNegInf:
4211+ case fcNormal:
4212+ case fcPosNormal:
4213+ case fcNegNormal:
4214+ case fcSubnormal:
4215+ case fcPosSubnormal:
4216+ case fcNegSubnormal:
4217+ case fcZero:
4218+ case fcPosZero:
4219+ case fcNegZero:
4220+ case fcFinite:
4221+ case fcPosFinite:
4222+ case fcNegFinite:
4223+ return InvertedTest;
4224+ }
4225+ return fcNone;
4226+ };
4227+
4228+ // if is possible to invert then invert the test
4229+ bool IsInverted = false ;
4230+ if (llvm::FPClassTest InvertedCheck = GetInvertedFPClassTest (FPClass)) {
4231+ IsInverted = true ;
4232+ FPClass = InvertedCheck;
4233+ }
4234+
4235+ auto GetInvertedTestIfNeeded = [&](Register src){
4236+ if (!IsInverted)
4237+ return src;
4238+ Register des = MRI->createVirtualRegister (MRI->getRegClass (src));
4239+ MIRBuilder.buildInstr (SPIRV::OpLogicalNot).addDef (des).addUse (GR.getSPIRVTypeID (boolType)).addUse (src);
4240+ return des;
4241+ };
4242+
4243+ // checking for IsNan
4244+ if (FPClass & fcNan) {
4245+ if (FPClass & fcSNan && FPClass & fcQNan) {
4246+ ResultVec.push_back (instructionTemplate (SPIRV::OpIsNan,boolType,boolType,srcReg));
4247+ } else {
4248+ // isquiet(V) ==> abs(V) >= (unsigned(Inf) | quiet_bit)
4249+ APInt QNaNBitMask = APInt::getOneBitSet (bitWidth, AllOneMantissa.getActiveBits () - 1 );
4250+ APInt InfWithQnanBit = Inf | QNaNBitMask;
4251+ int64_t InfWithQnanBitVal = InfWithQnanBit.getZExtValue ();
4252+ Register constInfwithQuanBit = createConstant (InfWithQnanBitVal);
4253+ Register constIntsrc = instructionTemplate (SPIRV::OpBitcast, intType, intType, srcReg);
4254+ Register TestQnan = instructionTemplate (SPIRV::OpUGreaterThanEqual, boolType, boolType, constIntsrc, constInfwithQuanBit);
4255+ if (FPClass & fcSNan){
4256+ // fcSNan = isNan && !isQNan
4257+ Register notQnan = instructionTemplate (SPIRV::OpLogicalNot, boolType, boolType, TestQnan);
4258+ Register IsNan = instructionTemplate (SPIRV::OpIsNan, boolType, boolType, srcReg);
4259+ ResultVec.push_back (instructionTemplate (SPIRV::OpLogicalAnd, boolType, boolType, IsNan, notQnan));
4260+ }else {
4261+ ResultVec.push_back (TestQnan);
4262+ }
4263+ }
4264+ }
4265+ // checking for isInf
4266+ if (FPClass & fcInf) {
4267+ Register IsInf = instructionTemplate (SPIRV::OpIsInf, boolType, boolType, srcReg);
4268+ if (!((FPClass & fcPosInf)&&(FPClass & fcNegInf))){
4269+ ResultVec.push_back (GetNegPosInstTest (IsInf, FPClass & fcNegInf));
4270+ }else {
4271+ ResultVec.push_back (IsInf);
4272+ }
4273+ }
4274+ // for handling is Normal
4275+ if (FPClass & fcNormal) {
4276+ Register isNormal = instructionTemplate (SPIRV::OpIsNormal, boolType, boolType, srcReg);
4277+ if (!((FPClass & fcNegNormal)&&(FPClass & fcPosNormal))){
4278+ ResultVec.push_back (GetNegPosInstTest (isNormal, FPClass & fcNegNormal));
4279+ }else {
4280+ ResultVec.push_back (isNormal);
4281+ }
4282+ }
4283+ // for handling subnormal
4284+ if (FPClass & fcSubnormal) {
4285+ // issubnormal(V) ==> unsigned(abs(V) - 1) < (all mantissa bits set)
4286+ // APInt zeros = APInt::getZero(bitWidth);
4287+ int64_t AllOneMantissa_int = AllOneMantissa.getZExtValue ();
4288+ Register constAllOneMantisa = createConstant (AllOneMantissa_int);
4289+ Register constantOne = createConstant (1 );
4290+ Register bitCastedsrc = instructionTemplate (SPIRV::OpBitcast, intType, intType, srcReg);
4291+ Register bitCastedSrcMinusOne = instructionTemplate (SPIRV::OpISubS, intType, intType, bitCastedsrc, constantOne);
4292+ Register testIsSubNormal = instructionTemplate (SPIRV::OpULessThan, boolType, boolType, bitCastedSrcMinusOne, constAllOneMantisa);
4293+ if (!((FPClass & fcNegSubnormal)&&(FPClass & fcPosSubnormal))){
4294+ ResultVec.push_back (GetNegPosInstTest (testIsSubNormal, FPClass & fcNegSubnormal));
4295+ }else {
4296+ ResultVec.push_back (testIsSubNormal);
4297+ }
4298+ }
4299+ // check if the number is Zero
4300+ if (FPClass & fcZero){
4301+ auto SetUpCMPToZero = [&](Register BitCastToInt,
4302+ bool IsPositive) -> Register {
4303+ APInt ZeroInt = APInt::getZero (bitWidth);
4304+ Register constantZero;
4305+ if (!IsPositive) {
4306+ ZeroInt.setSignBit ();
4307+ }
4308+ constantZero = createConstant (ZeroInt.getZExtValue ());
4309+ Register isEqual = instructionTemplate (SPIRV::OpIEqual, boolType, boolType, constantZero,BitCastToInt);
4310+ return isEqual;
4311+ };
4312+
4313+ Register bitCastedsrc = instructionTemplate (SPIRV::OpBitcast, intType, intType , srcReg);
4314+ if (FPClass & fcPosZero && FPClass & fcNegZero) {
4315+ APInt ZeroInt = APInt::getZero (bitWidth);
4316+ APInt MaskToClearSignBit = APInt::getSignedMaxValue (bitWidth);
4317+ Register MaskToClearSignBitConst = createConstant (MaskToClearSignBit.getZExtValue ());
4318+ Register zeroConst = createConstant (ZeroInt.getZExtValue ());
4319+ Register bitwiseAndRes = instructionTemplate (SPIRV::OpBitwiseAndS, intType, intType , MaskToClearSignBitConst, bitCastedsrc);
4320+ ResultVec.push_back (instructionTemplate (SPIRV::OpIEqual, boolType, boolType, bitwiseAndRes, zeroConst));
4321+ }else if (FPClass & fcPosZero) {
4322+ ResultVec.push_back (SetUpCMPToZero (bitCastedsrc, true ));
4323+ } else {
4324+ ResultVec.push_back (SetUpCMPToZero (bitCastedsrc, false ));
4325+ }
4326+ }
4327+ Register Result = ResultVec[0 ];
4328+ if (ResultVec.size () > 1 ){
4329+ for (size_t I = 1 ; I < ResultVec.size (); I++) {
4330+ Result = instructionTemplate (SPIRV::OpLogicalOr, boolType, boolType, Result , ResultVec[I]);
4331+ }
4332+ }
4333+ MIRBuilder.buildCopy (ResVReg, Result);
4334+ return true ;
4335+ }
4336+
4337+
40244338namespace llvm {
40254339InstructionSelector *
40264340createSPIRVInstructionSelector (const SPIRVTargetMachine &TM,
0 commit comments