1111// ===----------------------------------------------------------------------===//
1212
1313#include " NVPTXCtorDtorLowering.h"
14+ #include " MCTargetDesc/NVPTXBaseInfo.h"
1415#include " NVPTX.h"
1516#include " llvm/ADT/StringExtras.h"
1617#include " llvm/IR/Constants.h"
@@ -32,6 +33,11 @@ static cl::opt<std::string>
3233 cl::desc (" Override unique ID of ctor/dtor globals." ),
3334 cl::init(" " ), cl::Hidden);
3435
36+ static cl::opt<bool >
37+ CreateKernels (" nvptx-emit-init-fini-kernel" ,
38+ cl::desc (" Emit kernels to call ctor/dtor globals." ),
39+ cl::init(true ), cl::Hidden);
40+
3541namespace {
3642
3743static std::string getHash (StringRef Str) {
@@ -42,11 +48,163 @@ static std::string getHash(StringRef Str) {
4248 return llvm::utohexstr (Hash.low (), /* LowerCase=*/ true );
4349}
4450
45- static bool createInitOrFiniGlobls (Module &M, StringRef GlobalName,
46- bool IsCtor) {
47- GlobalVariable *GV = M.getGlobalVariable (GlobalName);
48- if (!GV || !GV->hasInitializer ())
49- return false ;
51+ static void addKernelMetadata (Module &M, GlobalValue *GV) {
52+ llvm::LLVMContext &Ctx = M.getContext ();
53+
54+ // Get "nvvm.annotations" metadata node.
55+ llvm::NamedMDNode *MD = M.getOrInsertNamedMetadata (" nvvm.annotations" );
56+
57+ llvm::Metadata *KernelMDVals[] = {
58+ llvm::ConstantAsMetadata::get (GV), llvm::MDString::get (Ctx, " kernel" ),
59+ llvm::ConstantAsMetadata::get (
60+ llvm::ConstantInt::get (llvm::Type::getInt32Ty (Ctx), 1 ))};
61+
62+ // This kernel is only to be called single-threaded.
63+ llvm::Metadata *ThreadXMDVals[] = {
64+ llvm::ConstantAsMetadata::get (GV), llvm::MDString::get (Ctx, " maxntidx" ),
65+ llvm::ConstantAsMetadata::get (
66+ llvm::ConstantInt::get (llvm::Type::getInt32Ty (Ctx), 1 ))};
67+ llvm::Metadata *ThreadYMDVals[] = {
68+ llvm::ConstantAsMetadata::get (GV), llvm::MDString::get (Ctx, " maxntidy" ),
69+ llvm::ConstantAsMetadata::get (
70+ llvm::ConstantInt::get (llvm::Type::getInt32Ty (Ctx), 1 ))};
71+ llvm::Metadata *ThreadZMDVals[] = {
72+ llvm::ConstantAsMetadata::get (GV), llvm::MDString::get (Ctx, " maxntidz" ),
73+ llvm::ConstantAsMetadata::get (
74+ llvm::ConstantInt::get (llvm::Type::getInt32Ty (Ctx), 1 ))};
75+
76+ llvm::Metadata *BlockMDVals[] = {
77+ llvm::ConstantAsMetadata::get (GV),
78+ llvm::MDString::get (Ctx, " maxclusterrank" ),
79+ llvm::ConstantAsMetadata::get (
80+ llvm::ConstantInt::get (llvm::Type::getInt32Ty (Ctx), 1 ))};
81+
82+ // Append metadata to nvvm.annotations.
83+ MD->addOperand (llvm::MDNode::get (Ctx, KernelMDVals));
84+ MD->addOperand (llvm::MDNode::get (Ctx, ThreadXMDVals));
85+ MD->addOperand (llvm::MDNode::get (Ctx, ThreadYMDVals));
86+ MD->addOperand (llvm::MDNode::get (Ctx, ThreadZMDVals));
87+ MD->addOperand (llvm::MDNode::get (Ctx, BlockMDVals));
88+ }
89+
90+ static Function *createInitOrFiniKernelFunction (Module &M, bool IsCtor) {
91+ StringRef InitOrFiniKernelName =
92+ IsCtor ? " nvptx$device$init" : " nvptx$device$fini" ;
93+ if (M.getFunction (InitOrFiniKernelName))
94+ return nullptr ;
95+
96+ Function *InitOrFiniKernel = Function::createWithDefaultAttr (
97+ FunctionType::get (Type::getVoidTy (M.getContext ()), false ),
98+ GlobalValue::WeakODRLinkage, 0 , InitOrFiniKernelName, &M);
99+ addKernelMetadata (M, InitOrFiniKernel);
100+
101+ return InitOrFiniKernel;
102+ }
103+
104+ // We create the IR required to call each callback in this section. This is
105+ // equivalent to the following code. Normally, the linker would provide us with
106+ // the definitions of the init and fini array sections. The 'nvlink' linker does
107+ // not do this so initializing these values is done by the runtime.
108+ //
109+ // extern "C" void **__init_array_start = nullptr;
110+ // extern "C" void **__init_array_end = nullptr;
111+ // extern "C" void **__fini_array_start = nullptr;
112+ // extern "C" void **__fini_array_end = nullptr;
113+ //
114+ // using InitCallback = void();
115+ // using FiniCallback = void();
116+ //
117+ // void call_init_array_callbacks() {
118+ // for (auto start = __init_array_start; start != __init_array_end; ++start)
119+ // reinterpret_cast<InitCallback *>(*start)();
120+ // }
121+ //
122+ // void call_init_array_callbacks() {
123+ // size_t fini_array_size = __fini_array_end - __fini_array_start;
124+ // for (size_t i = fini_array_size; i > 0; --i)
125+ // reinterpret_cast<FiniCallback *>(__fini_array_start[i - 1])();
126+ // }
127+ static void createInitOrFiniCalls (Function &F, bool IsCtor) {
128+ Module &M = *F.getParent ();
129+ LLVMContext &C = M.getContext ();
130+
131+ IRBuilder<> IRB (BasicBlock::Create (C, " entry" , &F));
132+ auto *LoopBB = BasicBlock::Create (C, " while.entry" , &F);
133+ auto *ExitBB = BasicBlock::Create (C, " while.end" , &F);
134+ Type *PtrTy = IRB.getPtrTy (llvm::ADDRESS_SPACE_GLOBAL);
135+
136+ auto *Begin = M.getOrInsertGlobal (
137+ IsCtor ? " __init_array_start" : " __fini_array_start" ,
138+ PointerType::get (C, 0 ), [&]() {
139+ auto *GV = new GlobalVariable (
140+ M, PointerType::get (C, 0 ),
141+ /* isConstant=*/ false , GlobalValue::WeakAnyLinkage,
142+ Constant::getNullValue (PointerType::get (C, 0 )),
143+ IsCtor ? " __init_array_start" : " __fini_array_start" ,
144+ /* InsertBefore=*/ nullptr , GlobalVariable::NotThreadLocal,
145+ /* AddressSpace=*/ llvm::ADDRESS_SPACE_GLOBAL);
146+ GV->setVisibility (GlobalVariable::ProtectedVisibility);
147+ return GV;
148+ });
149+ auto *End = M.getOrInsertGlobal (
150+ IsCtor ? " __init_array_end" : " __fini_array_end" , PointerType::get (C, 0 ),
151+ [&]() {
152+ auto *GV = new GlobalVariable (
153+ M, PointerType::get (C, 0 ),
154+ /* isConstant=*/ false , GlobalValue::WeakAnyLinkage,
155+ Constant::getNullValue (PointerType::get (C, 0 )),
156+ IsCtor ? " __init_array_end" : " __fini_array_end" ,
157+ /* InsertBefore=*/ nullptr , GlobalVariable::NotThreadLocal,
158+ /* AddressSpace=*/ llvm::ADDRESS_SPACE_GLOBAL);
159+ GV->setVisibility (GlobalVariable::ProtectedVisibility);
160+ return GV;
161+ });
162+
163+ // The constructor type is suppoed to allow using the argument vectors, but
164+ // for now we just call them with no arguments.
165+ auto *CallBackTy = FunctionType::get (IRB.getVoidTy (), {});
166+
167+ // The destructor array must be called in reverse order. Get an expression to
168+ // the end of the array and iterate backwards in that case.
169+ Value *BeginVal = IRB.CreateLoad (Begin->getType (), Begin, " begin" );
170+ Value *EndVal = IRB.CreateLoad (Begin->getType (), End, " stop" );
171+ if (!IsCtor) {
172+ auto *BeginInt = IRB.CreatePtrToInt (BeginVal, IntegerType::getInt64Ty (C));
173+ auto *EndInt = IRB.CreatePtrToInt (EndVal, IntegerType::getInt64Ty (C));
174+ auto *SubInst = IRB.CreateSub (EndInt, BeginInt);
175+ auto *Offset = IRB.CreateAShr (
176+ SubInst, ConstantInt::get (IntegerType::getInt64Ty (C), 3 ), " offset" ,
177+ /* IsExact=*/ true );
178+ auto *ValuePtr = IRB.CreateGEP (PointerType::get (C, 0 ), BeginVal,
179+ ArrayRef<Value *>({Offset}));
180+ EndVal = BeginVal;
181+ BeginVal = IRB.CreateInBoundsGEP (
182+ PointerType::get (C, 0 ), ValuePtr,
183+ ArrayRef<Value *>(ConstantInt::get (IntegerType::getInt64Ty (C), -1 )),
184+ " start" );
185+ }
186+ IRB.CreateCondBr (
187+ IRB.CreateCmp (IsCtor ? ICmpInst::ICMP_NE : ICmpInst::ICMP_UGT, BeginVal,
188+ EndVal),
189+ LoopBB, ExitBB);
190+ IRB.SetInsertPoint (LoopBB);
191+ auto *CallBackPHI = IRB.CreatePHI (PtrTy, 2 , " ptr" );
192+ auto *CallBack = IRB.CreateLoad (CallBackTy->getPointerTo (F.getAddressSpace ()),
193+ CallBackPHI, " callback" );
194+ IRB.CreateCall (CallBackTy, CallBack);
195+ auto *NewCallBack =
196+ IRB.CreateConstGEP1_64 (PtrTy, CallBackPHI, IsCtor ? 1 : -1 , " next" );
197+ auto *EndCmp = IRB.CreateCmp (IsCtor ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_ULT,
198+ NewCallBack, EndVal, " end" );
199+ CallBackPHI->addIncoming (BeginVal, &F.getEntryBlock ());
200+ CallBackPHI->addIncoming (NewCallBack, LoopBB);
201+ IRB.CreateCondBr (EndCmp, ExitBB, LoopBB);
202+ IRB.SetInsertPoint (ExitBB);
203+ IRB.CreateRetVoid ();
204+ }
205+
206+ static bool createInitOrFiniGlobals (Module &M, GlobalVariable *GV,
207+ bool IsCtor) {
50208 ConstantArray *GA = dyn_cast<ConstantArray>(GV->getInitializer ());
51209 if (!GA || GA->getNumOperands () == 0 )
52210 return false ;
@@ -81,14 +239,35 @@ static bool createInitOrFiniGlobls(Module &M, StringRef GlobalName,
81239 appendToUsed (M, {GV});
82240 }
83241
242+ return true ;
243+ }
244+
245+ static bool createInitOrFiniKernel (Module &M, StringRef GlobalName,
246+ bool IsCtor) {
247+ GlobalVariable *GV = M.getGlobalVariable (GlobalName);
248+ if (!GV || !GV->hasInitializer ())
249+ return false ;
250+
251+ if (!createInitOrFiniGlobals (M, GV, IsCtor))
252+ return false ;
253+
254+ if (!CreateKernels)
255+ return true ;
256+
257+ Function *InitOrFiniKernel = createInitOrFiniKernelFunction (M, IsCtor);
258+ if (!InitOrFiniKernel)
259+ return false ;
260+
261+ createInitOrFiniCalls (*InitOrFiniKernel, IsCtor);
262+
84263 GV->eraseFromParent ();
85264 return true ;
86265}
87266
88267static bool lowerCtorsAndDtors (Module &M) {
89268 bool Modified = false ;
90- Modified |= createInitOrFiniGlobls (M, " llvm.global_ctors" , /* IsCtor =*/ true );
91- Modified |= createInitOrFiniGlobls (M, " llvm.global_dtors" , /* IsCtor =*/ false );
269+ Modified |= createInitOrFiniKernel (M, " llvm.global_ctors" , /* IsCtor =*/ true );
270+ Modified |= createInitOrFiniKernel (M, " llvm.global_dtors" , /* IsCtor =*/ false );
92271 return Modified;
93272}
94273
0 commit comments