Skip to content

Commit e0c003b

Browse files
fix indirect calls for function pointers
1 parent 4281f29 commit e0c003b

File tree

6 files changed

+149
-2
lines changed

6 files changed

+149
-2
lines changed

llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,16 @@ void SPIRVAsmPrinter::outputModuleSections() {
600600
}
601601

602602
bool SPIRVAsmPrinter::doInitialization(Module &M) {
603+
// Discard the internal service function
604+
for (Function &F : M) {
605+
if (!F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
606+
continue;
607+
getAnalysis<MachineModuleInfoWrapperPass>()
608+
.getMMI()
609+
.deleteMachineFunctionFor(F);
610+
break;
611+
}
612+
603613
ModuleSectionsEmitted = false;
604614
// We need to call the parent's one explicitly.
605615
return AsmPrinter::doInitialization(M);

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
3636
const Value *Val, ArrayRef<Register> VRegs,
3737
FunctionLoweringInfo &FLI,
3838
Register SwiftErrorVReg) const {
39+
// Discard the internal service function
40+
if (FLI.Fn && FLI.Fn->getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
41+
return true;
42+
3943
// Maybe run postponed production of types for function pointers
4044
if (IndirectCalls.size() > 0) {
4145
produceIndirectPtrTypes(MIRBuilder);
@@ -280,6 +284,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
280284
const Function &F,
281285
ArrayRef<ArrayRef<Register>> VRegs,
282286
FunctionLoweringInfo &FLI) const {
287+
// Discard the internal service function
288+
if (F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
289+
return true;
290+
283291
assert(GR && "Must initialize the SPIRV type registry before lowering args.");
284292
GR->setCurrentFunc(MIRBuilder.getMF());
285293

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ class SPIRVEmitIntrinsics
147147
void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
148148
CallInst *AssignCI);
149149

150+
bool runOnFunction(Function &F);
151+
bool postprocessTypes();
152+
bool processFunctionPointers(Module &M);
153+
150154
public:
151155
static char ID;
152156
SPIRVEmitIntrinsics() : ModulePass(ID) {
@@ -173,8 +177,6 @@ class SPIRVEmitIntrinsics
173177
StringRef getPassName() const override { return "SPIRV emit intrinsics"; }
174178

175179
bool runOnModule(Module &M) override;
176-
bool runOnFunction(Function &F);
177-
bool postprocessTypes();
178180

179181
void getAnalysisUsage(AnalysisUsage &AU) const override {
180182
ModulePass::getAnalysisUsage(AU);
@@ -1825,10 +1827,57 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
18251827
}
18261828

18271829
Changed |= postprocessTypes();
1830+
Changed |= processFunctionPointers(M);
18281831

18291832
return Changed;
18301833
}
18311834

1835+
bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
1836+
bool IsExt = false;
1837+
SmallVector<Function*> Worklist;
1838+
for (auto &F : M) {
1839+
if (!IsExt) {
1840+
if (!TM->getSubtarget<SPIRVSubtarget>(F).canUseExtension(
1841+
SPIRV::Extension::SPV_INTEL_function_pointers))
1842+
return false;
1843+
IsExt = true;
1844+
}
1845+
if (!F.isDeclaration() || F.isIntrinsic())
1846+
continue;
1847+
for (User *U : F.users()) {
1848+
CallInst *CI = dyn_cast<CallInst>(U);
1849+
if (!CI || CI->getCalledFunction() != &F) {
1850+
Worklist.push_back(&F);
1851+
break;
1852+
}
1853+
}
1854+
}
1855+
if (Worklist.empty())
1856+
return false;
1857+
1858+
std::string ServiceFunName = SPIRV_BACKEND_SERVICE_FUN_NAME;
1859+
if (!getVacantFunctionName(M, ServiceFunName))
1860+
report_fatal_error(
1861+
"cannot allocate a name for the internal service function");
1862+
LLVMContext &Ctx = M.getContext();
1863+
Function *SF =
1864+
Function::Create(FunctionType::get(Type::getVoidTy(Ctx), {}, false),
1865+
GlobalValue::PrivateLinkage, ServiceFunName, M);
1866+
SF->addFnAttr(SPIRV_BACKEND_SERVICE_FUN_NAME, "");
1867+
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", SF);
1868+
IRBuilder<> IRB(BB);
1869+
1870+
for (Function *F : Worklist) {
1871+
SmallVector<Value *> Args;
1872+
for (const auto &Arg : F->args())
1873+
Args.push_back(PoisonValue::get(Arg.getType()));
1874+
IRB.CreateCall(F, Args);
1875+
}
1876+
IRB.CreateRetVoid();
1877+
1878+
return true;
1879+
}
1880+
18321881
ModulePass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
18331882
return new SPIRVEmitIntrinsics(TM);
18341883
}

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,4 +598,18 @@ MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) {
598598
return MaybeDef;
599599
}
600600

601+
bool getVacantFunctionName(Module &M, std::string &Name) {
602+
// It's a bit of paranoia, but still we don't want to have even a chance that
603+
// the loop will work for too long.
604+
constexpr unsigned MaxIters = 1024;
605+
for (unsigned I = 0; I < MaxIters; ++I) {
606+
std::string OrdName = Name + Twine(I).str();
607+
if (!M.getFunction(OrdName)) {
608+
Name = OrdName;
609+
return true;
610+
}
611+
}
612+
return false;
613+
}
614+
601615
} // namespace llvm

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,5 +341,8 @@ inline const Type *unifyPtrType(const Type *Ty) {
341341

342342
MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
343343

344+
#define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun"
345+
bool getVacantFunctionName(Module &M, std::string &Name);
346+
344347
} // namespace llvm
345348
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
2+
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK: OpFunction
5+
6+
%classid = type { %arrayid }
7+
%arrayid = type { [1 x i64] }
8+
%struct.obj_storage_t = type { %storage }
9+
%storage = type { [8 x i8] }
10+
11+
@_ZTV12IncrementBy8 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy89incrementEPi to ptr addrspace(4))] }, align 8
12+
@_ZTV13BaseIncrement = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN13BaseIncrement9incrementEPi to ptr addrspace(4))] }, align 8
13+
@_ZTV12IncrementBy4 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy49incrementEPi to ptr addrspace(4))] }, align 8
14+
@_ZTV12IncrementBy2 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy29incrementEPi to ptr addrspace(4))] }, align 8
15+
16+
define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%classid) align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) {
17+
entry:
18+
%0 = load i64, ptr %_arg_StorageAcc3, align 8
19+
%add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %0
20+
%arrayidx.ascast.i = addrspacecast ptr addrspace(1) %add.ptr.i to ptr addrspace(4)
21+
%cmp.i = icmp ugt i32 %_arg_TestCase, 3
22+
br i1 %cmp.i, label %entry.critedge, label %if.end.1
23+
24+
entry.critedge: ; preds = %entry
25+
%vtable.i.pre = load ptr addrspace(4), ptr addrspace(4) null, align 8
26+
br label %exit
27+
28+
if.end.1: ; preds = %entry
29+
switch i32 %_arg_TestCase, label %if.end.5 [
30+
i32 0, label %if.end.2
31+
i32 1, label %if.end.3
32+
i32 2, label %if.end.4
33+
]
34+
35+
if.end.5: ; preds = %if.end.1
36+
store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy8, i64 16), ptr addrspace(1) %add.ptr.i, align 8
37+
br label %exit
38+
39+
if.end.4: ; preds = %if.end.1
40+
store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy4, i64 16), ptr addrspace(1) %add.ptr.i, align 8
41+
br label %exit
42+
43+
if.end.3: ; preds = %if.end.1
44+
store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy2, i64 16), ptr addrspace(1) %add.ptr.i, align 8
45+
br label %exit
46+
47+
if.end.2: ; preds = %if.end.1
48+
store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV13BaseIncrement, i64 16), ptr addrspace(1) %add.ptr.i, align 8
49+
br label %exit
50+
51+
exit: ; preds = %if.end.2, %if.end.3, %if.end.4, %if.end.5, %entry.critedge
52+
%vtable.i = phi ptr addrspace(4) [ %vtable.i.pre, %entry.critedge ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy8, i64 16) to i64) to ptr addrspace(4)), %if.end.5 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy4, i64 16) to i64) to ptr addrspace(4)), %if.end.4 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy2, i64 16) to i64) to ptr addrspace(4)), %if.end.3 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV13BaseIncrement, i64 16) to i64) to ptr addrspace(4)), %if.end.2 ]
53+
%retval.0.i = phi ptr addrspace(4) [ null, %entry.critedge ], [ %arrayidx.ascast.i, %if.end.5 ], [ %arrayidx.ascast.i, %if.end.4 ], [ %arrayidx.ascast.i, %if.end.3 ], [ %arrayidx.ascast.i, %if.end.2 ]
54+
%1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4)
55+
%2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8
56+
tail call spir_func addrspace(4) void %2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %1)
57+
ret void
58+
}
59+
60+
declare dso_local spir_func void @_ZN13BaseIncrement9incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef)
61+
declare dso_local spir_func void @_ZN12IncrementBy29incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef)
62+
declare dso_local spir_func void @_ZN12IncrementBy49incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef)
63+
declare dso_local spir_func void @_ZN12IncrementBy89incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef)

0 commit comments

Comments
 (0)