|
40 | 40 | #include "SPIRVRegularizeLLVM.h" |
41 | 41 | #include "OCLUtil.h" |
42 | 42 | #include "SPIRVInternal.h" |
| 43 | +#include "SPIRVMDWalker.h" |
43 | 44 | #include "libSPIRV/SPIRVDebug.h" |
44 | 45 |
|
45 | 46 | #include "llvm/ADT/StringExtras.h" // llvm::isDigit |
@@ -620,6 +621,7 @@ void prepareCacheControlsTranslation(Metadata *MD, Instruction *Inst) { |
620 | 621 | /// Remove entities not representable by SPIR-V |
621 | 622 | bool SPIRVRegularizeLLVMBase::regularize() { |
622 | 623 | eraseUselessFunctions(M); |
| 624 | + addKernelEntryPoint(M); |
623 | 625 | expandSYCLTypeUsing(M); |
624 | 626 | cleanupConversionToNonStdIntegers(M); |
625 | 627 |
|
@@ -831,6 +833,69 @@ bool SPIRVRegularizeLLVMBase::regularize() { |
831 | 833 | return true; |
832 | 834 | } |
833 | 835 |
|
| 836 | +void SPIRVRegularizeLLVMBase::addKernelEntryPoint(Module *M) { |
| 837 | + std::vector<Function *> Work; |
| 838 | + |
| 839 | + // Get a list of all functions that have SPIR kernel calling conv |
| 840 | + for (auto &F : *M) { |
| 841 | + if (F.getCallingConv() == CallingConv::SPIR_KERNEL) |
| 842 | + Work.push_back(&F); |
| 843 | + } |
| 844 | + for (auto &F : Work) { |
| 845 | + // for declarations just make them into SPIR functions. |
| 846 | + F->setCallingConv(CallingConv::SPIR_FUNC); |
| 847 | + if (F->isDeclaration()) |
| 848 | + continue; |
| 849 | + |
| 850 | + // Otherwise add a wrapper around the function to act as an entry point. |
| 851 | + FunctionType *FType = F->getFunctionType(); |
| 852 | + std::string WrapName = |
| 853 | + kSPIRVName::EntrypointPrefix + static_cast<std::string>(F->getName()); |
| 854 | + Function *WrapFn = |
| 855 | + getOrCreateFunction(M, F->getReturnType(), FType->params(), WrapName); |
| 856 | + |
| 857 | + auto *CallBB = BasicBlock::Create(M->getContext(), "", WrapFn); |
| 858 | + IRBuilder<> Builder(CallBB); |
| 859 | + |
| 860 | + Function::arg_iterator DestI = WrapFn->arg_begin(); |
| 861 | + for (const Argument &I : F->args()) { |
| 862 | + DestI->setName(I.getName()); |
| 863 | + DestI++; |
| 864 | + } |
| 865 | + SmallVector<Value *, 1> Args; |
| 866 | + for (Argument &I : WrapFn->args()) { |
| 867 | + Args.emplace_back(&I); |
| 868 | + } |
| 869 | + auto *CI = CallInst::Create(F, ArrayRef<Value *>(Args), "", CallBB); |
| 870 | + CI->setCallingConv(F->getCallingConv()); |
| 871 | + CI->setAttributes(F->getAttributes()); |
| 872 | + |
| 873 | + // copy over all the metadata (should it be removed from F?) |
| 874 | + SmallVector<std::pair<unsigned, MDNode *>> MDs; |
| 875 | + F->getAllMetadata(MDs); |
| 876 | + WrapFn->setAttributes(F->getAttributes()); |
| 877 | + for (auto MD = MDs.begin(), End = MDs.end(); MD != End; ++MD) { |
| 878 | + WrapFn->addMetadata(MD->first, *MD->second); |
| 879 | + } |
| 880 | + WrapFn->setCallingConv(CallingConv::SPIR_KERNEL); |
| 881 | + WrapFn->setLinkage(llvm::GlobalValue::InternalLinkage); |
| 882 | + |
| 883 | + Builder.CreateRet(F->getReturnType()->isVoidTy() ? nullptr : CI); |
| 884 | + |
| 885 | + // Have to find the spir-v metadata for execution mode and transfer it to |
| 886 | + // the wrapper. |
| 887 | + if (auto NMD = SPIRVMDWalker(*M).getNamedMD(kSPIRVMD::ExecutionMode)) { |
| 888 | + while (!NMD.atEnd()) { |
| 889 | + Function *MDF = nullptr; |
| 890 | + auto N = NMD.nextOp(); /* execution mode MDNode */ |
| 891 | + N.get(MDF); |
| 892 | + if (MDF == F) |
| 893 | + N.M->replaceOperandWith(0, ValueAsMetadata::get(WrapFn)); |
| 894 | + } |
| 895 | + } |
| 896 | + } |
| 897 | +} |
| 898 | + |
834 | 899 | } // namespace SPIRV |
835 | 900 |
|
836 | 901 | INITIALIZE_PASS(SPIRVRegularizeLLVMLegacy, "spvregular", |
|
0 commit comments