Skip to content

[HLSL] Implement elementwise firstbitlow builtin #116858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 80 additions & 104 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3178,98 +3178,74 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
Register SrcReg, unsigned BitSetOpcode, bool SwapPrimarySide) const {

// SPIR-V only allows vecs of size 2,3,4. Calling with a larger vec requires
// creating a param reg and return reg with an invalid vec size. If that is
// resolved then this function is valid for vectors of any component size.
unsigned ComponentCount = GR.getScalarOrVectorComponentCount(ResType);
SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
bool ZeroAsNull = STI.isOpenCLEnv();
Register ConstIntZero =
GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
unsigned LeftComponentCount = ComponentCount / 2;
unsigned RightComponentCount = ComponentCount - LeftComponentCount;
bool LeftIsVector = LeftComponentCount > 1;
assert(ComponentCount < 5 && "Vec 5+ will generate invalid SPIR-V ops");

// Split the SrcReg in half into 2 smaller vec registers
// (ie i64x4 -> i64x2, i64x2)
bool ZeroAsNull = STI.isOpenCLEnv();
MachineIRBuilder MIRBuilder(I);
SPIRVType *OpType = GR.getOrCreateSPIRVIntegerType(64, MIRBuilder);
SPIRVType *LeftVecOpType;
SPIRVType *LeftVecResType;
if (LeftIsVector) {
LeftVecOpType =
GR.getOrCreateSPIRVVectorType(OpType, LeftComponentCount, MIRBuilder);
LeftVecResType =
GR.getOrCreateSPIRVVectorType(BaseType, LeftComponentCount, MIRBuilder);
} else {
LeftVecOpType = OpType;
LeftVecResType = BaseType;
}

SPIRVType *RightVecOpType =
GR.getOrCreateSPIRVVectorType(OpType, RightComponentCount, MIRBuilder);
SPIRVType *RightVecResType =
GR.getOrCreateSPIRVVectorType(BaseType, RightComponentCount, MIRBuilder);
SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
SPIRVType *I64Type = GR.getOrCreateSPIRVIntegerType(64, MIRBuilder);
SPIRVType *I64x2Type = GR.getOrCreateSPIRVVectorType(I64Type, 2, MIRBuilder);
SPIRVType *Vec2ResType =
GR.getOrCreateSPIRVVectorType(BaseType, 2, MIRBuilder);

Register LeftSideIn =
MRI->createVirtualRegister(GR.getRegClass(LeftVecOpType));
Register RightSideIn =
MRI->createVirtualRegister(GR.getRegClass(RightVecOpType));
std::vector<Register> PartialRegs;

bool Result;
// Loops 0, 2, 4, ... but stops one loop early when ComponentCount is odd
unsigned CurrentComponent = 0;
for (; CurrentComponent + 1 < ComponentCount; CurrentComponent += 2) {
Register SubVecReg = MRI->createVirtualRegister(GR.getRegClass(I64x2Type));

// Extract the left half from the SrcReg into LeftSideIn
// accounting for the special case when it only has one element
if (LeftIsVector) {
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
TII.get(SPIRV::OpVectorShuffle))
.addDef(LeftSideIn)
.addUse(GR.getSPIRVTypeID(LeftVecOpType))
.addDef(SubVecReg)
.addUse(GR.getSPIRVTypeID(I64x2Type))
.addUse(SrcReg)
// Per the spec, repeat the vector if only one vec is needed
.addUse(SrcReg);

for (unsigned J = 0; J < LeftComponentCount; J++) {
MIB.addImm(J);
}
MIB.addImm(CurrentComponent);
MIB.addImm(CurrentComponent + 1);

Result = MIB.constrainAllUses(TII, TRI, RBI);
} else {
Result =
selectOpWithSrcs(LeftSideIn, LeftVecOpType, I, {SrcReg, ConstIntZero},
SPIRV::OpVectorExtractDynamic);
}
if (!MIB.constrainAllUses(TII, TRI, RBI))
return false;

// Extract the right half from the SrcReg into RightSideIn.
// Right will always be a vector since the only time one element is left is
// when Component == 3, and in that case Left is one element.
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
TII.get(SPIRV::OpVectorShuffle))
.addDef(RightSideIn)
.addUse(GR.getSPIRVTypeID(RightVecOpType))
.addUse(SrcReg)
// Per the spec, repeat the vector if only one vec is needed
.addUse(SrcReg);
Register SubVecBitSetReg =
MRI->createVirtualRegister(GR.getRegClass(Vec2ResType));

if (!selectFirstBitSet64(SubVecBitSetReg, Vec2ResType, I, SubVecReg,
BitSetOpcode, SwapPrimarySide))
return false;

for (unsigned J = LeftComponentCount; J < ComponentCount; J++) {
MIB.addImm(J);
PartialRegs.push_back(SubVecBitSetReg);
}

Result = Result && MIB.constrainAllUses(TII, TRI, RBI);
// On odd component counts we need to handle one more component
if (CurrentComponent != ComponentCount) {
Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
Register ConstIntLastIdx = GR.getOrCreateConstInt(
ComponentCount - 1, I, BaseType, TII, ZeroAsNull);

// Recursively call selectFirstBitSet64 on the 2 halves
Register LeftSideOut =
MRI->createVirtualRegister(GR.getRegClass(LeftVecResType));
Register RightSideOut =
MRI->createVirtualRegister(GR.getRegClass(RightVecResType));
Result =
Result && selectFirstBitSet64(LeftSideOut, LeftVecResType, I, LeftSideIn,
BitSetOpcode, SwapPrimarySide);
Result =
Result && selectFirstBitSet64(RightSideOut, RightVecResType, I,
RightSideIn, BitSetOpcode, SwapPrimarySide);
if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
SPIRV::OpVectorExtractDynamic))
return false;

// Join the two resulting registers back into the return type
// (ie i32x2, i32x2 -> i32x4)
return Result &&
selectOpWithSrcs(ResVReg, ResType, I, {LeftSideOut, RightSideOut},
Register FinalElemBitSetReg =
MRI->createVirtualRegister(GR.getRegClass(BaseType));

if (!selectFirstBitSet64(FinalElemBitSetReg, BaseType, I, FinalElemReg,
BitSetOpcode, SwapPrimarySide))
return false;

PartialRegs.push_back(FinalElemBitSetReg);
}

// Join all the resulting registers back into the return type in order
// (ie i32x2, i32x2, i32x1 -> i32x5)
return selectOpWithSrcs(ResVReg, ResType, I, PartialRegs,
SPIRV::OpCompositeConstruct);
}

Expand Down Expand Up @@ -3299,13 +3275,15 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
GR.getOrCreateSPIRVVectorType(BaseType, 2 * ComponentCount, MIRBuilder);
Register BitcastReg =
MRI->createVirtualRegister(GR.getRegClass(PostCastType));
bool Result =
selectOpWithSrcs(BitcastReg, PostCastType, I, {SrcReg}, SPIRV::OpBitcast);

if (!selectOpWithSrcs(BitcastReg, PostCastType, I, {SrcReg},
SPIRV::OpBitcast))
return false;

// 2. Find the first set bit from the primary side for all the pieces in #1
Register FBSReg = MRI->createVirtualRegister(GR.getRegClass(PostCastType));
Result = Result && selectFirstBitSet32(FBSReg, PostCastType, I, BitcastReg,
BitSetOpcode);
if (!selectFirstBitSet32(FBSReg, PostCastType, I, BitcastReg, BitSetOpcode))
return false;

// 3. Split result vector into high bits and low bits
Register HighReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
Expand All @@ -3314,12 +3292,12 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
bool IsScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
if (IsScalarRes) {
// if scalar do a vector extract
Result =
Result && selectOpWithSrcs(HighReg, ResType, I, {FBSReg, ConstIntZero},
SPIRV::OpVectorExtractDynamic);
Result =
Result && selectOpWithSrcs(LowReg, ResType, I, {FBSReg, ConstIntOne},
SPIRV::OpVectorExtractDynamic);
if (!selectOpWithSrcs(HighReg, ResType, I, {FBSReg, ConstIntZero},
SPIRV::OpVectorExtractDynamic))
return false;
if (!selectOpWithSrcs(LowReg, ResType, I, {FBSReg, ConstIntOne},
SPIRV::OpVectorExtractDynamic))
return false;
} else {
// if vector do a shufflevector
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
Expand All @@ -3334,7 +3312,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
for (unsigned J = 0; J < ComponentCount * 2; J += 2) {
MIB.addImm(J);
}
Result = Result && MIB.constrainAllUses(TII, TRI, RBI);

if (!MIB.constrainAllUses(TII, TRI, RBI))
return false;

MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
TII.get(SPIRV::OpVectorShuffle))
Expand All @@ -3348,7 +3328,8 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
for (unsigned J = 1; J < ComponentCount * 2; J += 2) {
MIB.addImm(J);
}
Result = Result && MIB.constrainAllUses(TII, TRI, RBI);
if (!MIB.constrainAllUses(TII, TRI, RBI))
return false;
}

// 4. Check the result. When primary bits == -1 use secondary, otherwise use
Expand Down Expand Up @@ -3378,10 +3359,10 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
AddOp = SPIRV::OpIAddV;
}

Register PrimaryReg;
Register SecondaryReg;
Register PrimaryShiftReg;
Register SecondaryShiftReg;
Register PrimaryReg = HighReg;
Register SecondaryReg = LowReg;
Register PrimaryShiftReg = Reg32;
Register SecondaryShiftReg = Reg0;

// By default the emitted opcodes check for the set bit from the MSB side.
// Setting SwapPrimarySide checks the set bit from the LSB side
Expand All @@ -3390,32 +3371,27 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
SecondaryReg = HighReg;
PrimaryShiftReg = Reg0;
SecondaryShiftReg = Reg32;
} else {
PrimaryReg = HighReg;
SecondaryReg = LowReg;
PrimaryShiftReg = Reg32;
SecondaryShiftReg = Reg0;
}

// Check if the primary bits are == -1
Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
Result = Result && selectOpWithSrcs(BReg, BoolType, I,
{PrimaryReg, NegOneReg}, SPIRV::OpIEqual);
if (!selectOpWithSrcs(BReg, BoolType, I, {PrimaryReg, NegOneReg},
SPIRV::OpIEqual))
return false;

// Select secondary bits if true in BReg, otherwise primary bits
Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
Result =
Result && selectOpWithSrcs(TmpReg, ResType, I,
{BReg, SecondaryReg, PrimaryReg}, SelectOp);
if (!selectOpWithSrcs(TmpReg, ResType, I, {BReg, SecondaryReg, PrimaryReg},
SelectOp))
return false;

// 5. Add 32 when high bits are used, otherwise 0 for low bits
Register ValReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
Result = Result && selectOpWithSrcs(
ValReg, ResType, I,
{BReg, SecondaryShiftReg, PrimaryShiftReg}, SelectOp);
if (!selectOpWithSrcs(ValReg, ResType, I,
{BReg, SecondaryShiftReg, PrimaryShiftReg}, SelectOp))
return false;

return Result &&
selectOpWithSrcs(ResVReg, ResType, I, {ValReg, TmpReg}, AddOp);
return selectOpWithSrcs(ResVReg, ResType, I, {ValReg, TmpReg}, AddOp);
}

bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
Expand Down
Loading
Loading