Skip to content

Commit f4ef94a

Browse files
committed
Added support for isFpClass
1 parent b3d5056 commit f4ef94a

File tree

3 files changed

+706
-0
lines changed

3 files changed

+706
-0
lines changed

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
305305
bool selectImageWriteIntrinsic(MachineInstr &I) const;
306306
bool selectResourceGetPointer(Register &ResVReg, const SPIRVType *ResType,
307307
MachineInstr &I) const;
308+
bool selectIsFpclass(Register ResVReg, const SPIRVType *ResType,
309+
MachineInstr &I) const;
308310

309311
// Utilities
310312
std::pair<Register, bool>
@@ -893,6 +895,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
893895
case TargetOpcode::G_UBSANTRAP:
894896
case TargetOpcode::DBG_LABEL:
895897
return true;
898+
case TargetOpcode::G_IS_FPCLASS:
899+
return selectIsFpclass(ResVReg,ResType,I);
896900

897901
default:
898902
return false;
@@ -4021,6 +4025,316 @@ bool SPIRVInstructionSelector::loadHandleBeforePosition(
40214025
.constrainAllUses(TII, TRI, RBI);
40224026
}
40234027

4028+
llvm::Type* getLLVMType(LLT srcLLTType, llvm::LLVMContext &context) {
4029+
if (srcLLTType.isScalar()) {
4030+
switch (srcLLTType.getSizeInBits()) {
4031+
case 32: return llvm::Type::getFloatTy(context); // float
4032+
case 64: return llvm::Type::getDoubleTy(context); // double
4033+
default:
4034+
llvm_unreachable("Unsupported scalar floating-point type!");
4035+
}
4036+
} else if (srcLLTType.isVector()) {
4037+
unsigned numElements = srcLLTType.getNumElements();
4038+
LLT elemType = srcLLTType.getElementType();
4039+
4040+
if (elemType.isScalar()) {
4041+
llvm::Type* baseType = nullptr;
4042+
switch (elemType.getSizeInBits()) {
4043+
case 32: baseType = llvm::Type::getFloatTy(context); break;
4044+
case 64: baseType = llvm::Type::getDoubleTy(context); break;
4045+
case 16: baseType = llvm::Type::getHalfTy(context); break;
4046+
default:
4047+
llvm_unreachable("Unsupported vector element type!");
4048+
}
4049+
return llvm::VectorType::get(baseType, llvm::ElementCount::getFixed(numElements));
4050+
}
4051+
}
4052+
4053+
llvm_unreachable("Unsupported LLT type conversion!");
4054+
}
4055+
4056+
int getBitWidth(LLT srcLLTType){
4057+
int bitWidth;
4058+
if(srcLLTType.isScalar()){
4059+
bitWidth = srcLLTType.getSizeInBits();
4060+
}else if(srcLLTType.isVector()){
4061+
bitWidth = srcLLTType.getElementType().getSizeInBits();
4062+
}
4063+
return bitWidth;
4064+
}
4065+
4066+
bool SPIRVInstructionSelector::selectIsFpclass(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
4067+
MachineIRBuilder MIRBuilder(I);
4068+
//spivType creation for Creating Intermediate Registers
4069+
std::vector<Register> ResultVec;
4070+
Register srcReg = I.getOperand(1).getReg();
4071+
SPIRVType* srcSPIRVType = GR.getSPIRVTypeForVReg(srcReg);
4072+
LLT srcLLTType = GR.getRegType(srcSPIRVType);
4073+
Register ResReg = MRI->createGenericVirtualRegister(srcLLTType);
4074+
DstOp Res(ResReg);
4075+
Type* srcLLVMType = getLLVMType(srcLLTType, GR.CurMF->getFunction().getContext());
4076+
int bitWidth = getBitWidth(srcLLTType);
4077+
4078+
Type *IntOpLLVMTy = IntegerType::getIntNTy(GR.CurMF->getFunction().getContext(), bitWidth);
4079+
if (srcLLVMType->isVectorTy())
4080+
IntOpLLVMTy = FixedVectorType::get(IntOpLLVMTy, cast<FixedVectorType>(srcLLVMType)->getNumElements());
4081+
4082+
4083+
SPIRVType *boolType = GR.getOrCreateSPIRVBoolType(I, TII);
4084+
SPIRVType *intType = GR.getOrCreateSPIRVIntegerType(bitWidth,I,TII);
4085+
SPIRVType *ScalarInt = GR.getOrCreateSPIRVIntegerType(bitWidth,I,TII);
4086+
unsigned N = GR.getScalarOrVectorComponentCount(ResType);
4087+
if (N > 1){
4088+
boolType = GR.getOrCreateSPIRVVectorType(boolType, N, I, TII);
4089+
intType = GR.getOrCreateSPIRVVectorType(intType, N, I, TII);
4090+
}
4091+
4092+
//function to create constant
4093+
auto createConstant = [&](int64_t val) -> Register {
4094+
Register constant = MRI->createVirtualRegister(GR.getRegClass(ScalarInt));
4095+
MIRBuilder.buildInstr(SPIRV::OpConstantI)
4096+
.addDef(constant)
4097+
.addUse(GR.getSPIRVTypeID(ScalarInt))
4098+
.addImm(val);
4099+
if (srcLLVMType->isVectorTy()) {
4100+
Register compositeConstant = MRI->createVirtualRegister(GR.getRegClass(intType));
4101+
auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
4102+
.addDef(compositeConstant)
4103+
.addUse(GR.getSPIRVTypeID(intType));
4104+
unsigned numElements = srcLLTType.getNumElements();
4105+
for (unsigned i = 0; i < numElements; ++i) {
4106+
MIB.addUse(constant);
4107+
}
4108+
return compositeConstant;
4109+
}
4110+
return constant;
4111+
};
4112+
//Type* srcLLVMType;
4113+
//checking src type to create llvm Type Creation
4114+
//srcLLVMType = (bitWidth == 32) ? Type::getFloatTy(GR.CurMF->getFunction().getContext()) : Type::getDoubleTy(GR.CurMF->getFunction().getContext());
4115+
// llvm::FPClassTest FPClass = static_cast<llvm::FPClassTest>(I.getOperand(2).getImm());
4116+
llvm::FPClassTest FPClass = static_cast<llvm::FPClassTest>(I.getOperand(2).getImm());
4117+
//edge cases
4118+
if (FPClass == 0){
4119+
// Register trueVector = createConstant(1);
4120+
// MIRBuilder.buildCopy(ResVReg, trueVector);
4121+
return true;
4122+
}
4123+
if (FPClass == fcAllFlags){
4124+
// Register trueVector = createConstant(1);
4125+
// MIRBuilder.buildCopy(ResVReg, trueVector);
4126+
return true;
4127+
}
4128+
4129+
4130+
//Instruction Template
4131+
auto instructionTemplate = [&](unsigned Opcode, SPIRVType* DestType, SPIRVType* ReturnType, auto&&... args) -> Register {
4132+
Register result = MRI->createVirtualRegister(GR.getRegClass(DestType));
4133+
auto &Instr = MIRBuilder.buildInstr(Opcode)
4134+
.addDef(result)
4135+
.addUse(GR.getSPIRVTypeID(DestType));
4136+
4137+
([&](auto&& arg) __attribute__((optimize("O0"))) {
4138+
if(std::is_integral_v<std::decay_t<decltype(arg)>>) {
4139+
Instr.addImm(arg);
4140+
}else{
4141+
Instr.addUse(arg);
4142+
}
4143+
}(args), ...);
4144+
return result;
4145+
};
4146+
//function to check if the sign bit is set or not
4147+
//1 sign is set if 0 sign is not set
4148+
//inf && sign == 1 then -ve infinity
4149+
// inf && sign == 0 then not -ve infinity
4150+
//--------positive cases -------------//
4151+
// inf && ~sign - positive Infinity
4152+
// inf && ~sign - not positive Infinity
4153+
//sign bit test logic
4154+
Register SignBitTest = Register(0);
4155+
Register NoSignTest = Register(0);
4156+
auto GetNegPosInstTest = [&](Register TestInst,
4157+
bool IsNegative) -> Register {
4158+
Register result = MRI->createVirtualRegister(GR.getRegClass(boolType));
4159+
if(!SignBitTest.isValid()){
4160+
SignBitTest = MRI->createVirtualRegister(GR.getRegClass(boolType));
4161+
MIRBuilder.buildInstr(SPIRV::OpSignBitSet)
4162+
.addDef(SignBitTest)
4163+
.addUse(GR.getSPIRVTypeID(boolType))
4164+
.addUse(srcReg);
4165+
}
4166+
4167+
if (IsNegative) {
4168+
MIRBuilder.buildInstr(SPIRV::OpLogicalAnd)
4169+
.addDef(result)
4170+
.addUse(GR.getSPIRVTypeID(boolType))
4171+
.addUse(SignBitTest)
4172+
.addUse(TestInst);
4173+
return result;
4174+
}
4175+
if(!NoSignTest.isValid()){
4176+
NoSignTest = MRI->createVirtualRegister(GR.getRegClass(boolType));
4177+
MIRBuilder.buildInstr(SPIRV::OpLogicalNot)
4178+
.addDef(NoSignTest)
4179+
.addUse(GR.getSPIRVTypeID(boolType))
4180+
.addUse(SignBitTest);
4181+
}
4182+
4183+
MIRBuilder.buildInstr(SPIRV::OpLogicalAnd)
4184+
.addDef(result)
4185+
.addUse(GR.getSPIRVTypeID(boolType))
4186+
.addUse(NoSignTest)
4187+
.addUse(TestInst);
4188+
return result;
4189+
};
4190+
4191+
// if (srcLLVMType->isVectorTy())
4192+
// llvm::errs() << "is Vector Type" << "\n";
4193+
// IntOpLLVMTy = FixedVectorType::get(
4194+
// IntOpLLVMTy, cast<FixedVectorType>(srcLLVMType)->getNumElements());
4195+
4196+
const llvm::fltSemantics &Semantics =srcLLVMType->getScalarType()->getFltSemantics();
4197+
APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt();
4198+
APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
4199+
4200+
//Mask Inversion Logic
4201+
auto GetInvertedFPClassTest =
4202+
[](const llvm::FPClassTest Test) -> llvm::FPClassTest {
4203+
llvm::FPClassTest InvertedTest = ~Test & fcAllFlags;
4204+
switch (InvertedTest) {
4205+
case fcNan:
4206+
case fcSNan:
4207+
case fcQNan:
4208+
case fcInf:
4209+
case fcPosInf:
4210+
case fcNegInf:
4211+
case fcNormal:
4212+
case fcPosNormal:
4213+
case fcNegNormal:
4214+
case fcSubnormal:
4215+
case fcPosSubnormal:
4216+
case fcNegSubnormal:
4217+
case fcZero:
4218+
case fcPosZero:
4219+
case fcNegZero:
4220+
case fcFinite:
4221+
case fcPosFinite:
4222+
case fcNegFinite:
4223+
return InvertedTest;
4224+
}
4225+
return fcNone;
4226+
};
4227+
4228+
//if is possible to invert then invert the test
4229+
bool IsInverted = false;
4230+
if (llvm::FPClassTest InvertedCheck = GetInvertedFPClassTest(FPClass)) {
4231+
IsInverted = true;
4232+
FPClass = InvertedCheck;
4233+
}
4234+
4235+
auto GetInvertedTestIfNeeded = [&](Register src){
4236+
if (!IsInverted)
4237+
return src;
4238+
Register des = MRI->createVirtualRegister(MRI->getRegClass(src));
4239+
MIRBuilder.buildInstr(SPIRV::OpLogicalNot).addDef(des).addUse(GR.getSPIRVTypeID(boolType)).addUse(src);
4240+
return des;
4241+
};
4242+
4243+
//checking for IsNan
4244+
if (FPClass & fcNan) {
4245+
if (FPClass & fcSNan && FPClass & fcQNan) {
4246+
ResultVec.push_back(instructionTemplate(SPIRV::OpIsNan,boolType,boolType,srcReg));
4247+
} else {
4248+
// isquiet(V) ==> abs(V) >= (unsigned(Inf) | quiet_bit)
4249+
APInt QNaNBitMask = APInt::getOneBitSet(bitWidth, AllOneMantissa.getActiveBits() - 1);
4250+
APInt InfWithQnanBit = Inf | QNaNBitMask;
4251+
int64_t InfWithQnanBitVal = InfWithQnanBit.getZExtValue();
4252+
Register constInfwithQuanBit = createConstant(InfWithQnanBitVal);
4253+
Register constIntsrc = instructionTemplate(SPIRV::OpBitcast, intType, intType, srcReg);
4254+
Register TestQnan = instructionTemplate(SPIRV::OpUGreaterThanEqual, boolType, boolType, constIntsrc, constInfwithQuanBit);
4255+
if(FPClass & fcSNan){
4256+
//fcSNan = isNan && !isQNan
4257+
Register notQnan = instructionTemplate(SPIRV::OpLogicalNot, boolType, boolType, TestQnan);
4258+
Register IsNan = instructionTemplate(SPIRV::OpIsNan, boolType, boolType, srcReg);
4259+
ResultVec.push_back(instructionTemplate(SPIRV::OpLogicalAnd, boolType, boolType, IsNan, notQnan));
4260+
}else{
4261+
ResultVec.push_back(TestQnan);
4262+
}
4263+
}
4264+
}
4265+
//checking for isInf
4266+
if (FPClass & fcInf) {
4267+
Register IsInf = instructionTemplate(SPIRV::OpIsInf, boolType, boolType, srcReg);
4268+
if(!((FPClass & fcPosInf)&&(FPClass & fcNegInf))){
4269+
ResultVec.push_back(GetNegPosInstTest(IsInf, FPClass & fcNegInf));
4270+
}else{
4271+
ResultVec.push_back(IsInf);
4272+
}
4273+
}
4274+
//for handling is Normal
4275+
if (FPClass & fcNormal) {
4276+
Register isNormal = instructionTemplate(SPIRV::OpIsNormal, boolType, boolType, srcReg);
4277+
if(!((FPClass & fcNegNormal)&&(FPClass & fcPosNormal))){
4278+
ResultVec.push_back(GetNegPosInstTest(isNormal, FPClass & fcNegNormal));
4279+
}else{
4280+
ResultVec.push_back(isNormal);
4281+
}
4282+
}
4283+
//for handling subnormal
4284+
if (FPClass & fcSubnormal) {
4285+
// issubnormal(V) ==> unsigned(abs(V) - 1) < (all mantissa bits set)
4286+
//APInt zeros = APInt::getZero(bitWidth);
4287+
int64_t AllOneMantissa_int = AllOneMantissa.getZExtValue();
4288+
Register constAllOneMantisa = createConstant(AllOneMantissa_int);
4289+
Register constantOne = createConstant(1);
4290+
Register bitCastedsrc = instructionTemplate(SPIRV::OpBitcast, intType, intType, srcReg);
4291+
Register bitCastedSrcMinusOne = instructionTemplate(SPIRV::OpISubS, intType, intType, bitCastedsrc, constantOne);
4292+
Register testIsSubNormal = instructionTemplate(SPIRV::OpULessThan, boolType, boolType, bitCastedSrcMinusOne, constAllOneMantisa);
4293+
if(!((FPClass & fcNegSubnormal)&&(FPClass & fcPosSubnormal))){
4294+
ResultVec.push_back(GetNegPosInstTest(testIsSubNormal, FPClass & fcNegSubnormal));
4295+
}else{
4296+
ResultVec.push_back(testIsSubNormal);
4297+
}
4298+
}
4299+
//check if the number is Zero
4300+
if(FPClass & fcZero){
4301+
auto SetUpCMPToZero = [&](Register BitCastToInt,
4302+
bool IsPositive) -> Register {
4303+
APInt ZeroInt = APInt::getZero(bitWidth);
4304+
Register constantZero;
4305+
if (!IsPositive) {
4306+
ZeroInt.setSignBit();
4307+
}
4308+
constantZero = createConstant(ZeroInt.getZExtValue());
4309+
Register isEqual = instructionTemplate(SPIRV::OpIEqual, boolType, boolType, constantZero,BitCastToInt);
4310+
return isEqual;
4311+
};
4312+
4313+
Register bitCastedsrc = instructionTemplate(SPIRV::OpBitcast, intType, intType , srcReg);
4314+
if (FPClass & fcPosZero && FPClass & fcNegZero) {
4315+
APInt ZeroInt = APInt::getZero(bitWidth);
4316+
APInt MaskToClearSignBit = APInt::getSignedMaxValue(bitWidth);
4317+
Register MaskToClearSignBitConst = createConstant(MaskToClearSignBit.getZExtValue());
4318+
Register zeroConst = createConstant(ZeroInt.getZExtValue());
4319+
Register bitwiseAndRes = instructionTemplate(SPIRV::OpBitwiseAndS, intType, intType , MaskToClearSignBitConst, bitCastedsrc);
4320+
ResultVec.push_back(instructionTemplate(SPIRV::OpIEqual, boolType, boolType, bitwiseAndRes, zeroConst));
4321+
}else if (FPClass & fcPosZero) {
4322+
ResultVec.push_back(SetUpCMPToZero(bitCastedsrc, true));
4323+
} else {
4324+
ResultVec.push_back(SetUpCMPToZero(bitCastedsrc, false));
4325+
}
4326+
}
4327+
Register Result = ResultVec[0];
4328+
if(ResultVec.size() > 1){
4329+
for (size_t I = 1; I < ResultVec.size(); I++) {
4330+
Result = instructionTemplate(SPIRV::OpLogicalOr, boolType, boolType, Result , ResultVec[I]);
4331+
}
4332+
}
4333+
MIRBuilder.buildCopy(ResVReg, Result);
4334+
return true;
4335+
}
4336+
4337+
40244338
namespace llvm {
40254339
InstructionSelector *
40264340
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
354354
getActionDefinitionsBuilder(G_FCOPYSIGN)
355355
.legalForCartesianProduct(allFloatScalarsAndVectors,
356356
allFloatScalarsAndVectors);
357+
getActionDefinitionsBuilder(G_IS_FPCLASS)
358+
.legalForCartesianProduct(allBoolScalarsAndVectors, allFloatScalarsAndVectors);
357359

358360
getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
359361
allFloatScalarsAndVectors, allIntScalarsAndVectors);

0 commit comments

Comments
 (0)