Skip to content

Commit 86aa271

Browse files
committed
Add an entry point wrapper around functions (llvm pass)
SPIR-V spec states: "It is invalid for any function to be targeted by both an OpEntryPoint instruction and an OpFunctionCall instruction." In order to satisfy SPIR-V that entrypoints and functions must be different, this introduces an entrypoint wrapper around functions at the LLVM IR level, then fixes up a few things like naming at the SPIRV translation. Original commit: KhronosGroup/SPIRV-LLVM-Translator@85815e7
1 parent 3e06221 commit 86aa271

24 files changed

+156
-39
lines changed

llvm-spirv/lib/SPIRV/SPIRVInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ const static char ConvertHandleToImageINTEL[] = "ConvertHandleToImageINTEL";
369369
const static char ConvertHandleToSamplerINTEL[] = "ConvertHandleToSamplerINTEL";
370370
const static char ConvertHandleToSampledImageINTEL[] =
371371
"ConvertHandleToSampledImageINTEL";
372+
const static char EntrypointPrefix[] = "__spirv_entry_";
372373
} // namespace kSPIRVName
373374

374375
namespace kSPIRVPostfix {

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2976,6 +2976,23 @@ Function *SPIRVToLLVM::transFunction(SPIRVFunction *BF) {
29762976
return Loc->second;
29772977

29782978
auto IsKernel = isKernel(BF);
2979+
if (IsKernel) {
2980+
// search for a previous function with the same name
2981+
// upgrade it to a kernel and drop this if it's found
2982+
for (auto &I : FuncMap) {
2983+
auto BFName = I.getFirst()->getName();
2984+
if (BF->getName() == BFName) {
2985+
auto *F = I.getSecond();
2986+
F->setCallingConv(CallingConv::SPIR_KERNEL);
2987+
F->setLinkage(GlobalValue::ExternalLinkage);
2988+
F->setDSOLocal(false);
2989+
F = cast<Function>(mapValue(BF, F));
2990+
mapFunction(BF, F);
2991+
return F;
2992+
}
2993+
}
2994+
}
2995+
29792996
auto Linkage = IsKernel ? GlobalValue::ExternalLinkage : transLinkageType(BF);
29802997
FunctionType *FT = cast<FunctionType>(transType(BF->getFunctionType()));
29812998
std::string FuncName = BF->getName();

llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "SPIRVRegularizeLLVM.h"
4141
#include "OCLUtil.h"
4242
#include "SPIRVInternal.h"
43+
#include "SPIRVMDWalker.h"
4344
#include "libSPIRV/SPIRVDebug.h"
4445

4546
#include "llvm/ADT/StringExtras.h" // llvm::isDigit
@@ -620,6 +621,7 @@ void prepareCacheControlsTranslation(Metadata *MD, Instruction *Inst) {
620621
/// Remove entities not representable by SPIR-V
621622
bool SPIRVRegularizeLLVMBase::regularize() {
622623
eraseUselessFunctions(M);
624+
addKernelEntryPoint(M);
623625
expandSYCLTypeUsing(M);
624626
cleanupConversionToNonStdIntegers(M);
625627

@@ -831,6 +833,69 @@ bool SPIRVRegularizeLLVMBase::regularize() {
831833
return true;
832834
}
833835

836+
void SPIRVRegularizeLLVMBase::addKernelEntryPoint(Module *M) {
837+
std::vector<Function *> Work;
838+
839+
// Get a list of all functions that have SPIR kernel calling conv
840+
for (auto &F : *M) {
841+
if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
842+
Work.push_back(&F);
843+
}
844+
for (auto &F : Work) {
845+
// for declarations just make them into SPIR functions.
846+
F->setCallingConv(CallingConv::SPIR_FUNC);
847+
if (F->isDeclaration())
848+
continue;
849+
850+
// Otherwise add a wrapper around the function to act as an entry point.
851+
FunctionType *FType = F->getFunctionType();
852+
std::string WrapName =
853+
kSPIRVName::EntrypointPrefix + static_cast<std::string>(F->getName());
854+
Function *WrapFn =
855+
getOrCreateFunction(M, F->getReturnType(), FType->params(), WrapName);
856+
857+
auto *CallBB = BasicBlock::Create(M->getContext(), "", WrapFn);
858+
IRBuilder<> Builder(CallBB);
859+
860+
Function::arg_iterator DestI = WrapFn->arg_begin();
861+
for (const Argument &I : F->args()) {
862+
DestI->setName(I.getName());
863+
DestI++;
864+
}
865+
SmallVector<Value *, 1> Args;
866+
for (Argument &I : WrapFn->args()) {
867+
Args.emplace_back(&I);
868+
}
869+
auto *CI = CallInst::Create(F, ArrayRef<Value *>(Args), "", CallBB);
870+
CI->setCallingConv(F->getCallingConv());
871+
CI->setAttributes(F->getAttributes());
872+
873+
// copy over all the metadata (should it be removed from F?)
874+
SmallVector<std::pair<unsigned, MDNode *>> MDs;
875+
F->getAllMetadata(MDs);
876+
WrapFn->setAttributes(F->getAttributes());
877+
for (auto MD = MDs.begin(), End = MDs.end(); MD != End; ++MD) {
878+
WrapFn->addMetadata(MD->first, *MD->second);
879+
}
880+
WrapFn->setCallingConv(CallingConv::SPIR_KERNEL);
881+
WrapFn->setLinkage(llvm::GlobalValue::InternalLinkage);
882+
883+
Builder.CreateRet(F->getReturnType()->isVoidTy() ? nullptr : CI);
884+
885+
// Have to find the spir-v metadata for execution mode and transfer it to
886+
// the wrapper.
887+
if (auto NMD = SPIRVMDWalker(*M).getNamedMD(kSPIRVMD::ExecutionMode)) {
888+
while (!NMD.atEnd()) {
889+
Function *MDF = nullptr;
890+
auto N = NMD.nextOp(); /* execution mode MDNode */
891+
N.get(MDF);
892+
if (MDF == F)
893+
N.M->replaceOperandWith(0, ValueAsMetadata::get(WrapFn));
894+
}
895+
}
896+
}
897+
}
898+
834899
} // namespace SPIRV
835900

836901
INITIALIZE_PASS(SPIRVRegularizeLLVMLegacy, "spvregular",

llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ class SPIRVRegularizeLLVMBase {
5151
// Lower functions
5252
bool regularize();
5353

54+
// SPIR-V disallows functions being entrypoints and called
55+
// LLVM doesn't. This adds a wrapper around the entry point
56+
// that later SPIR-V writer renames.
57+
void addKernelEntryPoint(Module *M);
58+
5459
/// Some LLVM intrinsics that have no SPIR-V counterpart may be wrapped in
5560
/// @spirv.llvm_intrinsic_* function. During reverse translation from SPIR-V
5661
/// to LLVM IR we can detect this @spirv.llvm_intrinsic_* function and

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -875,13 +875,19 @@ SPIRVFunction *LLVMToSPIRVBase::transFunctionDecl(Function *F) {
875875
static_cast<SPIRVFunction *>(mapValue(F, BM->addFunction(BFT)));
876876
BF->setFunctionControlMask(transFunctionControlMask(F));
877877
if (F->hasName()) {
878-
if (isUniformGroupOperation(F))
879-
BM->getErrorLog().checkError(
880-
BM->isAllowedToUseExtension(
881-
ExtensionID::SPV_KHR_uniform_group_instructions),
882-
SPIRVEC_RequiresExtension, "SPV_KHR_uniform_group_instructions\n");
883-
884-
BM->setName(BF, F->getName().str());
878+
if (isKernel(F)) {
879+
/* strip the prefix as the runtime will be looking for this name */
880+
std::string Prefix = kSPIRVName::EntrypointPrefix;
881+
std::string Name = F->getName().str();
882+
BM->setName(BF, Name.substr(Prefix.size()));
883+
} else {
884+
if (isUniformGroupOperation(F))
885+
BM->getErrorLog().checkError(
886+
BM->isAllowedToUseExtension(
887+
ExtensionID::SPV_KHR_uniform_group_instructions),
888+
SPIRVEC_RequiresExtension, "SPV_KHR_uniform_group_instructions\n");
889+
BM->setName(BF, F->getName().str());
890+
}
885891
}
886892
if (!isKernel(F) && F->getLinkage() != GlobalValue::InternalLinkage)
887893
BF->setLinkageType(transLinkageType(F));
@@ -5721,7 +5727,7 @@ void LLVMToSPIRVBase::transFunction(Function *I) {
57215727

57225728
if (isKernel(I)) {
57235729
auto Interface = collectEntryPointInterfaces(BF, I);
5724-
BM->addEntryPoint(ExecutionModelKernel, BF->getId(), I->getName().str(),
5730+
BM->addEntryPoint(ExecutionModelKernel, BF->getId(), BF->getName(),
57255731
Interface);
57265732
}
57275733
}
@@ -6088,8 +6094,9 @@ bool LLVMToSPIRVBase::transMetadata() {
60886094
// Work around to translate kernel_arg_type and kernel_arg_type_qual metadata
60896095
static void transKernelArgTypeMD(SPIRVModule *BM, Function *F, MDNode *MD,
60906096
std::string MDName) {
6091-
std::string KernelArgTypesMDStr =
6092-
std::string(MDName) + "." + F->getName().str() + ".";
6097+
std::string Prefix = kSPIRVName::EntrypointPrefix;
6098+
std::string Name = F->getName().str().substr(Prefix.size());
6099+
std::string KernelArgTypesMDStr = std::string(MDName) + "." + Name + ".";
60936100
for (const auto &TyOp : MD->operands())
60946101
KernelArgTypesMDStr += cast<MDString>(TyOp)->getString().str() + ",";
60956102
BM->getString(KernelArgTypesMDStr);
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
;; Test to check that an LLVM spir_kernel gets translated into an
2+
;; Entrypoint wrapper and Function with LinkageAttributes
3+
; RUN: llvm-as %s -o %t.bc
4+
; RUN: llvm-spirv %t.bc -o - -spirv-text | FileCheck %s --check-prefix=CHECK-SPIRV
5+
; RUN: llvm-spirv %t.bc -o %t.spv
6+
; RUN: spirv-val %t.spv
7+
8+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir64-unknown-unknown"
10+
11+
define spir_kernel void @testfunction() {
12+
ret void
13+
}
14+
15+
; Check there is an entrypoint and a function produced.
16+
; CHECK-SPIRV: EntryPoint 6 [[EP:[0-9]+]] "testfunction"
17+
; CHECK-SPIRV: Name [[FUNC:[0-9]+]] "testfunction"
18+
; CHECK-SPIRV: Decorate [[FUNC]] LinkageAttributes "testfunction" Export
19+
; CHECK-SPIRV: Function 2 [[FUNC]] 0 3
20+
; CHECK-SPIRV: Function 2 [[EP]] 0 3
21+
; CHECK-SPIRV: FunctionCall 2 8 [[FUNC]]

llvm-spirv/test/extensions/INTEL/SPV_INTEL_function_pointers/alias.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ target triple = "spir64-unknown-unknown"
1212
; when used since they can't be translated directly.
1313

1414
; CHECK-SPIRV-DAG: Name [[#FOO:]] "foo"
15-
; CHECK-SPIRV-DAG: EntryPoint [[#]] [[#BAR:]] "bar"
15+
; CHECK-SPIRV-DAG: Name [[#BAR:]] "bar"
1616
; CHECK-SPIRV-DAG: Name [[#Y:]] "y"
1717
; CHECK-SPIRV-DAG: Name [[#FOOPTR:]] "foo.alias"
1818
; CHECK-SPIRV-DAG: Decorate [[#FOO]] LinkageAttributes "foo" Export

llvm-spirv/test/extensions/INTEL/SPV_INTEL_function_pointers/fp-from-host.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
; CHECK-SPIRV: Capability FunctionPointersINTEL
1818
; CHECK-SPIRV: Extension "SPV_INTEL_function_pointers"
1919
;
20-
; CHECK-SPIRV: EntryPoint [[#]] [[KERNEL_ID:[0-9]+]] "test"
20+
; CHECK-SPIRV: Name [[KERNEL_ID:[0-9]+]] "test"
2121
; CHECK-SPIRV: TypeInt [[INT32_TYPE_ID:[0-9]+]] 32
2222
; CHECK-SPIRV: TypePointer [[INT_PTR:[0-9]+]] 5 [[INT32_TYPE_ID]]
2323
; CHECK-SPIRV: TypeFunction [[FOO_TYPE_ID:[0-9]+]] [[INT32_TYPE_ID]] [[INT32_TYPE_ID]]

llvm-spirv/test/extensions/INTEL/SPV_INTEL_function_pointers/function-pointer-as-function-arg.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
; CHECK-SPIRV: Capability FunctionPointersINTEL
3434
; CHECK-SPIRV: Extension "SPV_INTEL_function_pointers"
3535
;
36-
; CHECK-SPIRV: EntryPoint [[#]] [[KERNEL_ID:[0-9]+]] "test"
36+
; CHECK-SPIRV: Name [[KERNEL_ID:[0-9]+]] "test"
3737
; CHECK-SPIRV: TypeInt [[TYPE_INT32_ID:[0-9]+]] 32
3838
; CHECK-SPIRV: TypeFunction [[FOO_TYPE_ID:[0-9]+]] [[TYPE_INT32_ID]] [[TYPE_INT32_ID]]
3939
; CHECK-SPIRV: TypePointer [[FOO_PTR_TYPE_ID:[0-9]+]] {{[0-9]+}} [[FOO_TYPE_ID]]

llvm-spirv/test/extensions/INTEL/SPV_INTEL_function_pointers/function-pointer.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
;
2020
; CHECK-SPIRV: Capability FunctionPointersINTEL
2121
; CHECK-SPIRV: Extension "SPV_INTEL_function_pointers"
22-
; CHECK-SPIRV: EntryPoint [[#]] [[KERNEL_ID:[0-9]+]] "test"
22+
; CHECK-SPIRV: Name [[KERNEL_ID:[0-9]+]] "test"
2323
; CHECK-SPIRV: TypeInt [[TYPE_INT_ID:[0-9]+]]
2424
; CHECK-SPIRV: TypeFunction [[FOO_TYPE_ID:[0-9]+]] [[TYPE_INT_ID]] [[TYPE_INT_ID]]
2525
; CHECK-SPIRV: TypePointer [[FOO_PTR_ID:[0-9]+]] {{[0-9]+}} [[FOO_TYPE_ID]]

0 commit comments

Comments
 (0)