Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty,
llvm_i32_ty, llvm_ptr_ty],
[IntrNoMem]>;
def int_spv_resource_counterhandlefromimplicitbinding
: DefaultAttrsIntrinsic<[llvm_any_ty],
[llvm_any_ty, llvm_i32_ty, llvm_i32_ty],
[IntrNoMem]>;
def int_spv_resource_counterhandlefrombinding
: DefaultAttrsIntrinsic<[llvm_any_ty],
[llvm_any_ty, llvm_i32_ty, llvm_i32_ty],
[IntrNoMem]>;

def int_spv_firstbituhigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>;
def int_spv_firstbitshigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>;
Expand Down
134 changes: 134 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,19 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectHandleFromBinding(Register &ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectCounterHandleFromBinding(Register &ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const;

bool selectReadImageIntrinsic(Register &ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectImageWriteIntrinsic(MachineInstr &I) const;
bool selectResourceGetPointer(Register &ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectModf(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectUpdateCounter(Register &ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectFrexp(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
// Utilities
Expand Down Expand Up @@ -3441,6 +3447,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_resource_handlefrombinding: {
return selectHandleFromBinding(ResVReg, ResType, I);
}
case Intrinsic::spv_resource_counterhandlefrombinding:
return selectCounterHandleFromBinding(ResVReg, ResType, I);
case Intrinsic::spv_resource_updatecounter:
return selectUpdateCounter(ResVReg, ResType, I);
case Intrinsic::spv_resource_store_typedbuffer: {
return selectImageWriteIntrinsic(I);
}
Expand Down Expand Up @@ -3479,6 +3489,130 @@ bool SPIRVInstructionSelector::selectHandleFromBinding(Register &ResVReg,
*cast<GIntrinsic>(&I), I);
}

bool SPIRVInstructionSelector::selectCounterHandleFromBinding(
Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
auto &Intr = cast<GIntrinsic>(I);
assert(Intr.getIntrinsicID() ==
Intrinsic::spv_resource_counterhandlefrombinding);

// Extract information from the intrinsic call.
Register MainHandleReg = Intr.getOperand(2).getReg();
auto *MainHandleDef = cast<GIntrinsic>(getVRegDef(*MRI, MainHandleReg));
assert(MainHandleDef->getIntrinsicID() ==
Intrinsic::spv_resource_handlefrombinding);

uint32_t Set = getIConstVal(Intr.getOperand(4).getReg(), MRI);
uint32_t Binding = getIConstVal(Intr.getOperand(3).getReg(), MRI);
uint32_t ArraySize = getIConstVal(MainHandleDef->getOperand(4).getReg(), MRI);
Register IndexReg = MainHandleDef->getOperand(5).getReg();
const bool IsNonUniform = false;
std::string CounterName =
getStringValueFromReg(MainHandleDef->getOperand(6).getReg(), *MRI) +
".counter";

// Create the counter variable.
MachineIRBuilder MIRBuilder(I);
Register CounterVarReg = buildPointerToResource(
GR.getPointeeType(ResType), GR.getPointerStorageClass(ResType), Set,
Binding, ArraySize, IndexReg, IsNonUniform, CounterName, MIRBuilder);

return BuildCOPY(ResVReg, CounterVarReg, I);
}

bool SPIRVInstructionSelector::selectUpdateCounter(Register &ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
auto &Intr = cast<GIntrinsic>(I);
assert(Intr.getIntrinsicID() == Intrinsic::spv_resource_updatecounter);

Register CounterHandleReg = Intr.getOperand(2).getReg();
Register IncrReg = Intr.getOperand(3).getReg();

// The counter handle is a pointer to the counter variable (which is a struct
// containing an i32). We need to get a pointer to that i32 member to do the
// atomic operation.
#ifndef NDEBUG
SPIRVType *CounterVarType = GR.getSPIRVTypeForVReg(CounterHandleReg);
SPIRVType *CounterVarPointeeType = GR.getPointeeType(CounterVarType);
assert(CounterVarPointeeType &&
CounterVarPointeeType->getOpcode() == SPIRV::OpTypeStruct &&
"Counter variable must be a struct");
assert(GR.getPointerStorageClass(CounterVarType) ==
SPIRV::StorageClass::StorageBuffer &&
"Counter variable must be in the storage buffer storage class");
assert(CounterVarPointeeType->getNumOperands() == 2 &&
"Counter variable must have exactly 1 member in the struct");
const SPIRVType *MemberType =
GR.getSPIRVTypeForVReg(CounterVarPointeeType->getOperand(1).getReg());
assert(MemberType->getOpcode() == SPIRV::OpTypeInt &&
"Counter variable struct must have a single i32 member");
#endif

// The struct has a single i32 member.
MachineIRBuilder MIRBuilder(I);
const Type *LLVMIntType =
Type::getInt32Ty(I.getMF()->getFunction().getContext());

SPIRVType *IntPtrType = GR.getOrCreateSPIRVPointerType(
LLVMIntType, MIRBuilder, SPIRV::StorageClass::StorageBuffer);

auto Zero = buildI32Constant(0, I);
if (!Zero.second)
return false;

Register PtrToCounter =
MRI->createVirtualRegister(GR.getRegClass(IntPtrType));
if (!BuildMI(*I.getParent(), I, I.getDebugLoc(),
TII.get(SPIRV::OpAccessChain))
.addDef(PtrToCounter)
.addUse(GR.getSPIRVTypeID(IntPtrType))
.addUse(CounterHandleReg)
.addUse(Zero.first)
.constrainAllUses(TII, TRI, RBI)) {
return false;
}

// For UAV/SSBO counters, the scope is Device. The counter variable is not
// used as a flag. So the memory semantics can be None.
auto Scope = buildI32Constant(SPIRV::Scope::Device, I);
if (!Scope.second)
return false;
auto Semantics = buildI32Constant(SPIRV::MemorySemantics::None, I);
if (!Semantics.second)
return false;

int64_t IncrVal = getIConstValSext(IncrReg, MRI);
auto Incr = buildI32Constant(static_cast<uint32_t>(IncrVal), I);
if (!Incr.second)
return false;

Register AtomicRes = MRI->createVirtualRegister(GR.getRegClass(ResType));
if (!BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpAtomicIAdd))
.addDef(AtomicRes)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(PtrToCounter)
.addUse(Scope.first)
.addUse(Semantics.first)
.addUse(Incr.first)
.constrainAllUses(TII, TRI, RBI)) {
return false;
}
if (IncrVal >= 0) {
return BuildCOPY(ResVReg, AtomicRes, I);
}

// In HLSL, IncrementCounter returns the value *before* the increment, while
// DecrementCounter returns the value *after* the decrement. Both are lowered
// to the same atomic intrinsic which returns the value *before* the
// operation. So for decrements (negative IncrVal), we must subtract the
// increment value from the result to get the post-decrement value.
return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(AtomicRes)
.addUse(Incr.first)
.constrainAllUses(TII, TRI, RBI);
}
bool SPIRVInstructionSelector::selectReadImageIntrinsic(
Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const {

Expand Down
135 changes: 95 additions & 40 deletions llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class SPIRVLegalizeImplicitBinding : public ModulePass {
void collectBindingInfo(Module &M);
uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet);
void replaceImplicitBindingCalls(Module &M);
void replaceResourceHandleCall(Module &M, CallInst *OldCI);
void replaceCounterHandleCall(Module &M, CallInst *OldCI);

// A map from descriptor set to a bit vector of used binding numbers.
std::vector<BitVector> UsedBindings;
Expand All @@ -55,43 +57,61 @@ struct BindingInfoCollector : public InstVisitor<BindingInfoCollector> {
: UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) {
}

void addBinding(uint32_t DescSet, uint32_t Binding) {
if (UsedBindings.size() <= DescSet) {
UsedBindings.resize(DescSet + 1);
UsedBindings[DescSet].resize(64);
}
if (UsedBindings[DescSet].size() <= Binding) {
UsedBindings[DescSet].resize(2 * Binding + 1);
}
UsedBindings[DescSet].set(Binding);
}

void visitCallInst(CallInst &CI) {
if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) {
const uint32_t DescSet =
cast<ConstantInt>(CI.getArgOperand(0))->getZExtValue();
const uint32_t Binding =
cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();

if (UsedBindings.size() <= DescSet) {
UsedBindings.resize(DescSet + 1);
UsedBindings[DescSet].resize(64);
}
if (UsedBindings[DescSet].size() <= Binding) {
UsedBindings[DescSet].resize(2 * Binding + 1);
}
UsedBindings[DescSet].set(Binding);
addBinding(DescSet, Binding);
} else if (CI.getIntrinsicID() ==
Intrinsic::spv_resource_handlefromimplicitbinding) {
ImplicitBindingCalls.push_back(&CI);
} else if (CI.getIntrinsicID() ==
Intrinsic::spv_resource_counterhandlefrombinding) {
const uint32_t DescSet =
cast<ConstantInt>(CI.getArgOperand(2))->getZExtValue();
const uint32_t Binding =
cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();
addBinding(DescSet, Binding);
} else if (CI.getIntrinsicID() ==
Intrinsic::spv_resource_counterhandlefromimplicitbinding) {
ImplicitBindingCalls.push_back(&CI);
}
}
};

static uint32_t getOrderId(const CallInst *CI) {
switch (CI->getIntrinsicID()) {
case Intrinsic::spv_resource_handlefromimplicitbinding:
return cast<ConstantInt>(CI->getArgOperand(0))->getZExtValue();
case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
return cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
default:
llvm_unreachable("CallInst is not an implicit binding intrinsic");
}
}

void SPIRVLegalizeImplicitBinding::collectBindingInfo(Module &M) {
BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls);
InfoCollector.visit(M);

// Sort the collected calls by their order ID.
std::sort(
ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(),
[](const CallInst *A, const CallInst *B) {
const uint32_t OrderIdArgIdx = 0;
const uint32_t OrderA =
cast<ConstantInt>(A->getArgOperand(OrderIdArgIdx))->getZExtValue();
const uint32_t OrderB =
cast<ConstantInt>(B->getArgOperand(OrderIdArgIdx))->getZExtValue();
return OrderA < OrderB;
});
std::sort(ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(),
[](const CallInst *A, const CallInst *B) {
return getOrderId(A) < getOrderId(B);
});
}

uint32_t SPIRVLegalizeImplicitBinding::getAndReserveFirstUnusedBinding(
Expand All @@ -113,27 +133,15 @@ uint32_t SPIRVLegalizeImplicitBinding::getAndReserveFirstUnusedBinding(

void SPIRVLegalizeImplicitBinding::replaceImplicitBindingCalls(Module &M) {
for (CallInst *OldCI : ImplicitBindingCalls) {
IRBuilder<> Builder(OldCI);
const uint32_t DescSet =
cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue();
const uint32_t NewBinding = getAndReserveFirstUnusedBinding(DescSet);

SmallVector<Value *, 8> Args;
Args.push_back(Builder.getInt32(DescSet));
Args.push_back(Builder.getInt32(NewBinding));

// Copy the remaining arguments from the old call.
for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
Args.push_back(OldCI->getArgOperand(i));
if (OldCI->getIntrinsicID() ==
Intrinsic::spv_resource_handlefromimplicitbinding) {
replaceResourceHandleCall(M, OldCI);
} else {
assert(OldCI->getIntrinsicID() ==
Intrinsic::spv_resource_counterhandlefromimplicitbinding &&
"Unexpected implicit binding intrinsic");
replaceCounterHandleCall(M, OldCI);
}

Function *NewFunc = Intrinsic::getOrInsertDeclaration(
&M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType());
CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
NewCI->setCallingConv(OldCI->getCallingConv());

OldCI->replaceAllUsesWith(NewCI);
OldCI->eraseFromParent();
}
}

Expand All @@ -155,4 +163,51 @@ INITIALIZE_PASS(SPIRVLegalizeImplicitBinding, "legalize-spirv-implicit-binding",

ModulePass *llvm::createSPIRVLegalizeImplicitBindingPass() {
return new SPIRVLegalizeImplicitBinding();
}
}

void SPIRVLegalizeImplicitBinding::replaceResourceHandleCall(Module &M,
CallInst *OldCI) {
IRBuilder<> Builder(OldCI);
const uint32_t DescSet =
cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue();
const uint32_t NewBinding = getAndReserveFirstUnusedBinding(DescSet);

SmallVector<Value *, 8> Args;
Args.push_back(Builder.getInt32(DescSet));
Args.push_back(Builder.getInt32(NewBinding));

// Copy the remaining arguments from the old call.
for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
Args.push_back(OldCI->getArgOperand(i));
}

Function *NewFunc = Intrinsic::getOrInsertDeclaration(
&M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType());
CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
NewCI->setCallingConv(OldCI->getCallingConv());

OldCI->replaceAllUsesWith(NewCI);
OldCI->eraseFromParent();
}

void SPIRVLegalizeImplicitBinding::replaceCounterHandleCall(Module &M,
CallInst *OldCI) {
IRBuilder<> Builder(OldCI);
const uint32_t DescSet =
cast<ConstantInt>(OldCI->getArgOperand(2))->getZExtValue();
const uint32_t NewBinding = getAndReserveFirstUnusedBinding(DescSet);

SmallVector<Value *, 8> Args;
Args.push_back(OldCI->getArgOperand(0));
Args.push_back(Builder.getInt32(NewBinding));
Args.push_back(Builder.getInt32(DescSet));

Type *Tys[] = {OldCI->getType(), OldCI->getArgOperand(0)->getType()};
Function *NewFunc = Intrinsic::getOrInsertDeclaration(
&M, Intrinsic::spv_resource_counterhandlefrombinding, Tys);
CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
NewCI->setCallingConv(OldCI->getCallingConv());

OldCI->replaceAllUsesWith(NewCI);
OldCI->eraseFromParent();
}
6 changes: 6 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,12 @@ uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
return MI->getOperand(1).getCImm()->getValue().getZExtValue();
}

int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI) {
const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
return MI->getOperand(1).getCImm()->getSExtValue();
}

bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
return GI->is(IntrinsicID);
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
// Get constant integer value of the given ConstReg.
uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI);

// Get constant integer value of the given ConstReg, sign-extended.
int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI);

// Check if MI is a SPIR-V specific intrinsic call.
bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID);
// Check if it's a SPIR-V specific intrinsic call.
Expand Down
Loading