Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 325 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,15 @@
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"

#define DEBUG_TYPE "spirv-isel"
#ifdef _MSC_VER
#define NO_OPTIMIZE __pragma(optimize("", off))
#elif defined(__GNUC__) || defined(__clang__)
#define NO_OPTIMIZE __attribute__((optimize("O0")))
#else
#define NO_OPTIMIZE
#endif


using namespace llvm;
namespace CL = SPIRV::OpenCLExtInst;
Expand Down Expand Up @@ -305,6 +312,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectImageWriteIntrinsic(MachineInstr &I) const;
bool selectResourceGetPointer(Register &ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectIsFpclass(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

// Utilities
std::pair<Register, bool>
Expand Down Expand Up @@ -893,6 +902,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_UBSANTRAP:
case TargetOpcode::DBG_LABEL:
return true;
case TargetOpcode::G_IS_FPCLASS:
return selectIsFpclass(ResVReg,ResType,I);

default:
return false;
Expand Down Expand Up @@ -4021,6 +4032,319 @@ bool SPIRVInstructionSelector::loadHandleBeforePosition(
.constrainAllUses(TII, TRI, RBI);
}

llvm::Type* getLLVMType(LLT srcLLTType, llvm::LLVMContext &context) {
if (srcLLTType.isScalar()) {
switch (srcLLTType.getSizeInBits()) {
case 32: return llvm::Type::getFloatTy(context); // float
case 64: return llvm::Type::getDoubleTy(context); // double
default:
llvm_unreachable("Unsupported scalar floating-point type!");
}
} else if (srcLLTType.isVector()) {
unsigned numElements = srcLLTType.getNumElements();
LLT elemType = srcLLTType.getElementType();

if (elemType.isScalar()) {
llvm::Type* baseType = nullptr;
switch (elemType.getSizeInBits()) {
case 32: baseType = llvm::Type::getFloatTy(context); break;
case 64: baseType = llvm::Type::getDoubleTy(context); break;
case 16: baseType = llvm::Type::getHalfTy(context); break;
default:
llvm_unreachable("Unsupported vector element type!");
}
return llvm::VectorType::get(baseType, llvm::ElementCount::getFixed(numElements));
}
}

llvm_unreachable("Unsupported LLT type conversion!");
}

int getBitWidth(LLT srcLLTType){
int bitWidth;
if(srcLLTType.isScalar()){
bitWidth = srcLLTType.getSizeInBits();
}else if(srcLLTType.isVector()){
bitWidth = srcLLTType.getElementType().getSizeInBits();
}
return bitWidth;
}

bool SPIRVInstructionSelector::selectIsFpclass(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
MachineIRBuilder MIRBuilder(I);
//spivType creation for Creating Intermediate Registers
std::vector<Register> ResultVec;
Register srcReg = I.getOperand(1).getReg();
SPIRVType* srcSPIRVType = GR.getSPIRVTypeForVReg(srcReg);
LLT srcLLTType = GR.getRegType(srcSPIRVType);
Register ResReg = MRI->createGenericVirtualRegister(srcLLTType);
DstOp Res(ResReg);
Type* srcLLVMType = getLLVMType(srcLLTType, GR.CurMF->getFunction().getContext());
int bitWidth = getBitWidth(srcLLTType);

Type *IntOpLLVMTy = IntegerType::getIntNTy(GR.CurMF->getFunction().getContext(), bitWidth);
if (srcLLVMType->isVectorTy())
IntOpLLVMTy = FixedVectorType::get(IntOpLLVMTy, cast<FixedVectorType>(srcLLVMType)->getNumElements());


SPIRVType *boolType = GR.getOrCreateSPIRVBoolType(I, TII);
SPIRVType *intType = GR.getOrCreateSPIRVIntegerType(bitWidth,I,TII);
SPIRVType *ScalarInt = GR.getOrCreateSPIRVIntegerType(bitWidth,I,TII);
unsigned N = GR.getScalarOrVectorComponentCount(ResType);
if (N > 1){
boolType = GR.getOrCreateSPIRVVectorType(boolType, N, I, TII);
intType = GR.getOrCreateSPIRVVectorType(intType, N, I, TII);
}

//function to create constant
auto createConstant = [&](int64_t val) -> Register {
Register constant = MRI->createVirtualRegister(GR.getRegClass(ScalarInt));
MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(constant)
.addUse(GR.getSPIRVTypeID(ScalarInt))
.addImm(val);
if (srcLLVMType->isVectorTy()) {
Register compositeConstant = MRI->createVirtualRegister(GR.getRegClass(intType));
auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
.addDef(compositeConstant)
.addUse(GR.getSPIRVTypeID(intType));
unsigned numElements = srcLLTType.getNumElements();
for (unsigned i = 0; i < numElements; ++i) {
MIB.addUse(constant);
}
return compositeConstant;
}
return constant;
};
//Type* srcLLVMType;
//checking src type to create llvm Type Creation
//srcLLVMType = (bitWidth == 32) ? Type::getFloatTy(GR.CurMF->getFunction().getContext()) : Type::getDoubleTy(GR.CurMF->getFunction().getContext());
// llvm::FPClassTest FPClass = static_cast<llvm::FPClassTest>(I.getOperand(2).getImm());
llvm::FPClassTest FPClass = static_cast<llvm::FPClassTest>(I.getOperand(2).getImm());
//edge cases
if (FPClass == 0){
// Register trueVector = createConstant(1);
// MIRBuilder.buildCopy(ResVReg, trueVector);
return true;
}
if (FPClass == fcAllFlags){
// Register trueVector = createConstant(1);
// MIRBuilder.buildCopy(ResVReg, trueVector);
return true;
}


// Instruction Template
auto instructionTemplate = [&](unsigned Opcode, SPIRVType* DestType, SPIRVType* ReturnType, auto&&... args) -> Register {
Register result = MRI->createVirtualRegister(GR.getRegClass(DestType));
auto &Instr = MIRBuilder.buildInstr(Opcode)
.addDef(result)
.addUse(GR.getSPIRVTypeID(DestType));

(void)std::initializer_list<int>{
([&](auto&& arg) NO_OPTIMIZE {
if constexpr (std::is_integral_v<std::decay_t<decltype(arg)>>) {
Instr.addImm(arg);
} else {
Instr.addUse(arg);
}
}(args), 0)...
};
return result;
};

//function to check if the sign bit is set or not
//1 sign is set if 0 sign is not set
//inf && sign == 1 then -ve infinity
// inf && sign == 0 then not -ve infinity
//--------positive cases -------------//
// inf && ~sign - positive Infinity
// inf && ~sign - not positive Infinity
//sign bit test logic
Register SignBitTest = Register(0);
Register NoSignTest = Register(0);
auto GetNegPosInstTest = [&](Register TestInst,
bool IsNegative) -> Register {
Register result = MRI->createVirtualRegister(GR.getRegClass(boolType));
if(!SignBitTest.isValid()){
SignBitTest = MRI->createVirtualRegister(GR.getRegClass(boolType));
MIRBuilder.buildInstr(SPIRV::OpSignBitSet)
.addDef(SignBitTest)
.addUse(GR.getSPIRVTypeID(boolType))
.addUse(srcReg);
}

if (IsNegative) {
MIRBuilder.buildInstr(SPIRV::OpLogicalAnd)
.addDef(result)
.addUse(GR.getSPIRVTypeID(boolType))
.addUse(SignBitTest)
.addUse(TestInst);
return result;
}
if(!NoSignTest.isValid()){
NoSignTest = MRI->createVirtualRegister(GR.getRegClass(boolType));
MIRBuilder.buildInstr(SPIRV::OpLogicalNot)
.addDef(NoSignTest)
.addUse(GR.getSPIRVTypeID(boolType))
.addUse(SignBitTest);
}

MIRBuilder.buildInstr(SPIRV::OpLogicalAnd)
.addDef(result)
.addUse(GR.getSPIRVTypeID(boolType))
.addUse(NoSignTest)
.addUse(TestInst);
return result;
};

// if (srcLLVMType->isVectorTy())
// llvm::errs() << "is Vector Type" << "\n";
// IntOpLLVMTy = FixedVectorType::get(
// IntOpLLVMTy, cast<FixedVectorType>(srcLLVMType)->getNumElements());

const llvm::fltSemantics &Semantics =srcLLVMType->getScalarType()->getFltSemantics();
APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt();
APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;

//Mask Inversion Logic
auto GetInvertedFPClassTest =
[](const llvm::FPClassTest Test) -> llvm::FPClassTest {
llvm::FPClassTest InvertedTest = ~Test & fcAllFlags;
switch (InvertedTest) {
case fcNan:
case fcSNan:
case fcQNan:
case fcInf:
case fcPosInf:
case fcNegInf:
case fcNormal:
case fcPosNormal:
case fcNegNormal:
case fcSubnormal:
case fcPosSubnormal:
case fcNegSubnormal:
case fcZero:
case fcPosZero:
case fcNegZero:
case fcFinite:
case fcPosFinite:
case fcNegFinite:
return InvertedTest;
}
return fcNone;
};

//if is possible to invert then invert the test
bool IsInverted = false;
if (llvm::FPClassTest InvertedCheck = GetInvertedFPClassTest(FPClass)) {
IsInverted = true;
FPClass = InvertedCheck;
}

auto GetInvertedTestIfNeeded = [&](Register src){
if (!IsInverted)
return src;
Register des = MRI->createVirtualRegister(MRI->getRegClass(src));
MIRBuilder.buildInstr(SPIRV::OpLogicalNot).addDef(des).addUse(GR.getSPIRVTypeID(boolType)).addUse(src);
return des;
};

//checking for IsNan
if (FPClass & fcNan) {
if (FPClass & fcSNan && FPClass & fcQNan) {
ResultVec.push_back(instructionTemplate(SPIRV::OpIsNan,boolType,boolType,srcReg));
} else {
// isquiet(V) ==> abs(V) >= (unsigned(Inf) | quiet_bit)
APInt QNaNBitMask = APInt::getOneBitSet(bitWidth, AllOneMantissa.getActiveBits() - 1);
APInt InfWithQnanBit = Inf | QNaNBitMask;
int64_t InfWithQnanBitVal = InfWithQnanBit.getZExtValue();
Register constInfwithQuanBit = createConstant(InfWithQnanBitVal);
Register constIntsrc = instructionTemplate(SPIRV::OpBitcast, intType, intType, srcReg);
Register TestQnan = instructionTemplate(SPIRV::OpUGreaterThanEqual, boolType, boolType, constIntsrc, constInfwithQuanBit);
if(FPClass & fcSNan){
//fcSNan = isNan && !isQNan
Register notQnan = instructionTemplate(SPIRV::OpLogicalNot, boolType, boolType, TestQnan);
Register IsNan = instructionTemplate(SPIRV::OpIsNan, boolType, boolType, srcReg);
ResultVec.push_back(instructionTemplate(SPIRV::OpLogicalAnd, boolType, boolType, IsNan, notQnan));
}else{
ResultVec.push_back(TestQnan);
}
}
}
//checking for isInf
if (FPClass & fcInf) {
Register IsInf = instructionTemplate(SPIRV::OpIsInf, boolType, boolType, srcReg);
if(!((FPClass & fcPosInf)&&(FPClass & fcNegInf))){
ResultVec.push_back(GetNegPosInstTest(IsInf, FPClass & fcNegInf));
}else{
ResultVec.push_back(IsInf);
}
}
//for handling is Normal
if (FPClass & fcNormal) {
Register isNormal = instructionTemplate(SPIRV::OpIsNormal, boolType, boolType, srcReg);
if(!((FPClass & fcNegNormal)&&(FPClass & fcPosNormal))){
ResultVec.push_back(GetNegPosInstTest(isNormal, FPClass & fcNegNormal));
}else{
ResultVec.push_back(isNormal);
}
}
//for handling subnormal
if (FPClass & fcSubnormal) {
// issubnormal(V) ==> unsigned(abs(V) - 1) < (all mantissa bits set)
//APInt zeros = APInt::getZero(bitWidth);
int64_t AllOneMantissa_int = AllOneMantissa.getZExtValue();
Register constAllOneMantisa = createConstant(AllOneMantissa_int);
Register constantOne = createConstant(1);
Register bitCastedsrc = instructionTemplate(SPIRV::OpBitcast, intType, intType, srcReg);
Register bitCastedSrcMinusOne = instructionTemplate(SPIRV::OpISubS, intType, intType, bitCastedsrc, constantOne);
Register testIsSubNormal = instructionTemplate(SPIRV::OpULessThan, boolType, boolType, bitCastedSrcMinusOne, constAllOneMantisa);
if(!((FPClass & fcNegSubnormal)&&(FPClass & fcPosSubnormal))){
ResultVec.push_back(GetNegPosInstTest(testIsSubNormal, FPClass & fcNegSubnormal));
}else{
ResultVec.push_back(testIsSubNormal);
}
}
//check if the number is Zero
if(FPClass & fcZero){
auto SetUpCMPToZero = [&, bitWidth](Register BitCastToInt,
bool IsPositive) -> Register {
APInt ZeroInt = APInt::getZero(bitWidth);
Register constantZero;
if (!IsPositive) {
ZeroInt.setSignBit();
}
constantZero = createConstant(ZeroInt.getZExtValue());
Register isEqual = instructionTemplate(SPIRV::OpIEqual, boolType, boolType, constantZero,BitCastToInt);
return isEqual;
};

Register bitCastedsrc = instructionTemplate(SPIRV::OpBitcast, intType, intType , srcReg);
if (FPClass & fcPosZero && FPClass & fcNegZero) {
APInt ZeroInt = APInt::getZero(bitWidth);
APInt MaskToClearSignBit = APInt::getSignedMaxValue(bitWidth);
Register MaskToClearSignBitConst = createConstant(MaskToClearSignBit.getZExtValue());
Register zeroConst = createConstant(ZeroInt.getZExtValue());
Register bitwiseAndRes = instructionTemplate(SPIRV::OpBitwiseAndS, intType, intType , MaskToClearSignBitConst, bitCastedsrc);
ResultVec.push_back(instructionTemplate(SPIRV::OpIEqual, boolType, boolType, bitwiseAndRes, zeroConst));
}else if (FPClass & fcPosZero) {
ResultVec.push_back(SetUpCMPToZero(bitCastedsrc, true));
} else {
ResultVec.push_back(SetUpCMPToZero(bitCastedsrc, false));
}
}
Register Result = ResultVec[0];
if(ResultVec.size() > 1){
for (size_t I = 1; I < ResultVec.size(); I++) {
Result = instructionTemplate(SPIRV::OpLogicalOr, boolType, boolType, Result , ResultVec[I]);
}
}
MIRBuilder.buildCopy(ResVReg, Result);
return true;
}


namespace llvm {
InstructionSelector *
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
getActionDefinitionsBuilder(G_FCOPYSIGN)
.legalForCartesianProduct(allFloatScalarsAndVectors,
allFloatScalarsAndVectors);
getActionDefinitionsBuilder(G_IS_FPCLASS)
.legalForCartesianProduct(allBoolScalarsAndVectors, allFloatScalarsAndVectors);

getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
allFloatScalarsAndVectors, allIntScalarsAndVectors);
Expand Down
Loading