4848#include " llvm/Analysis/OptimizationRemarkEmitter.h"
4949#include " llvm/IR/Constants.h"
5050#include " llvm/IR/Function.h"
51+ #include " llvm/IR/IRBuilder.h"
5152#include " llvm/IR/Module.h"
5253#include " llvm/Transforms/Utils/ModuleUtils.h"
5354
@@ -63,7 +64,7 @@ static inline void eraseFromModule(T &ToErase) {
6364 ToErase.eraseFromParent ();
6465}
6566
66- static inline bool checkIfSupported (GlobalVariable &G) {
67+ static bool checkIfSupported (GlobalVariable &G) {
6768 if (!G.isThreadLocal ())
6869 return true ;
6970
@@ -114,24 +115,221 @@ static inline void clearModule(Module &M) { // TODO: simplify.
114115 eraseFromModule (*M.ifuncs ().begin ());
115116}
116117
118+ static SmallVector<std::reference_wrapper<Use>>
119+ collectIndirectableUses (GlobalVariable *G) {
120+ // We are interested only in use chains that end in an Instruction.
121+ SmallVector<std::reference_wrapper<Use>> Uses;
122+
123+ SmallVector<std::reference_wrapper<Use>> Stack (G->use_begin (), G->use_end ());
124+ while (!Stack.empty ()) {
125+ Use &U = Stack.pop_back_val ();
126+ if (isa<Instruction>(U.getUser ()))
127+ Uses.emplace_back (U);
128+ else
129+ transform (U.getUser ()->uses (), std::back_inserter (Stack),
130+ [](auto &&U) { return std::ref (U); });
131+ }
132+
133+ return Uses;
134+ }
135+
136+ static inline GlobalVariable *getGlobalForName (GlobalVariable *G) {
137+ // Create an anonymous global which stores the variable's name, which will be
138+ // used by the HIPSTDPAR runtime to look up the program-wide symbol.
139+ LLVMContext &Ctx = G->getContext ();
140+ auto *CDS = ConstantDataArray::getString (Ctx, G->getName ());
141+
142+ GlobalVariable *N = G->getParent ()->getOrInsertGlobal (" " , CDS->getType ());
143+ N->setInitializer (CDS);
144+ N->setLinkage (GlobalValue::LinkageTypes::PrivateLinkage);
145+ N->setConstant (true );
146+
147+ return N;
148+ }
149+
150+ static inline GlobalVariable *getIndirectionGlobal (Module *M) {
151+ // Create an anonymous global which stores a pointer to a pointer, which will
152+ // be externally initialised by the HIPSTDPAR runtime with the address of the
153+ // program-wide symbol.
154+ Type *PtrTy = PointerType::get (
155+ M->getContext (), M->getDataLayout ().getDefaultGlobalsAddressSpace ());
156+ GlobalVariable *NewG = M->getOrInsertGlobal (" " , PtrTy);
157+
158+ NewG->setInitializer (PoisonValue::get (NewG->getValueType ()));
159+ NewG->setLinkage (GlobalValue::LinkageTypes::PrivateLinkage);
160+ NewG->setConstant (true );
161+ NewG->setExternallyInitialized (true );
162+
163+ return NewG;
164+ }
165+
166+ static Constant *
167+ appendIndirectedGlobal (const GlobalVariable *IndirectionTable,
168+ SmallVector<Constant *> &SymbolIndirections,
169+ GlobalVariable *ToIndirect) {
170+ Module *M = ToIndirect->getParent ();
171+
172+ auto *InitTy = cast<StructType>(IndirectionTable->getValueType ());
173+ auto *SymbolListTy = cast<StructType>(InitTy->getStructElementType (2 ));
174+ Type *NameTy = SymbolListTy->getElementType (0 );
175+ Type *IndirectTy = SymbolListTy->getElementType (1 );
176+
177+ Constant *NameG = getGlobalForName (ToIndirect);
178+ Constant *IndirectG = getIndirectionGlobal (M);
179+ Constant *Entry = ConstantStruct::get (
180+ SymbolListTy, {ConstantExpr::getAddrSpaceCast (NameG, NameTy),
181+ ConstantExpr::getAddrSpaceCast (IndirectG, IndirectTy)});
182+ SymbolIndirections.push_back (Entry);
183+
184+ return IndirectG;
185+ }
186+
187+ static void fillIndirectionTable (GlobalVariable *IndirectionTable,
188+ SmallVector<Constant *> Indirections) {
189+ Module *M = IndirectionTable->getParent ();
190+ size_t SymCnt = Indirections.size ();
191+
192+ auto *InitTy = cast<StructType>(IndirectionTable->getValueType ());
193+ Type *SymbolListTy = InitTy->getStructElementType (1 );
194+ auto *SymbolTy = cast<StructType>(InitTy->getStructElementType (2 ));
195+
196+ Constant *Count = ConstantInt::get (InitTy->getStructElementType (0 ), SymCnt);
197+ M->removeGlobalVariable (IndirectionTable);
198+ GlobalVariable *Symbols =
199+ M->getOrInsertGlobal (" " , ArrayType::get (SymbolTy, SymCnt));
200+ Symbols->setLinkage (GlobalValue::LinkageTypes::PrivateLinkage);
201+ Symbols->setInitializer (
202+ ConstantArray::get (ArrayType::get (SymbolTy, SymCnt), {Indirections}));
203+ Symbols->setConstant (true );
204+
205+ Constant *ASCSymbols = ConstantExpr::getAddrSpaceCast (Symbols, SymbolListTy);
206+ Constant *Init = ConstantStruct::get (
207+ InitTy, {Count, ASCSymbols, PoisonValue::get (SymbolTy)});
208+ M->insertGlobalVariable (IndirectionTable);
209+ IndirectionTable->setInitializer (Init);
210+ }
211+
212+ static void replaceWithIndirectUse (const Use &U, const GlobalVariable *G,
213+ Constant *IndirectedG) {
214+ auto *I = cast<Instruction>(U.getUser ());
215+
216+ IRBuilder<> Builder (I);
217+ unsigned OpIdx = U.getOperandNo ();
218+ Value *Op = I->getOperand (OpIdx);
219+
220+ // We walk back up the use chain, which could be an arbitrarily long sequence
221+ // of constexpr AS casts, ptr-to-int and GEP instructions, until we reach the
222+ // indirected global.
223+ while (auto *CE = dyn_cast<ConstantExpr>(Op)) {
224+ assert ((CE->getOpcode () == Instruction::GetElementPtr ||
225+ CE->getOpcode () == Instruction::AddrSpaceCast ||
226+ CE->getOpcode () == Instruction::PtrToInt) &&
227+ " Only GEP, ASCAST or PTRTOINT constant uses supported!" );
228+
229+ Instruction *NewI = Builder.Insert (CE->getAsInstruction ());
230+ I->replaceUsesOfWith (Op, NewI);
231+ I = NewI;
232+ Op = I->getOperand (0 );
233+ OpIdx = 0 ;
234+ Builder.SetInsertPoint (I);
235+ }
236+
237+ assert (Op == G && " Must reach indirected global!" );
238+
239+ I->setOperand (OpIdx, Builder.CreateLoad (G->getType (), IndirectedG));
240+ }
241+
242+ static inline bool isValidIndirectionTable (GlobalVariable *IndirectionTable) {
243+ std::string W;
244+ raw_string_ostream OS (W);
245+
246+ Type *Ty = IndirectionTable->getValueType ();
247+ bool Valid = false ;
248+
249+ if (!isa<StructType>(Ty)) {
250+ OS << " The Indirection Table must be a struct type; " ;
251+ Ty->print (OS);
252+ OS << " is incorrect.\n " ;
253+ } else if (cast<StructType>(Ty)->getNumElements () != 3u ) {
254+ OS << " The Indirection Table must have 3 elements; "
255+ << cast<StructType>(Ty)->getNumElements () << " is incorrect.\n " ;
256+ } else if (!isa<IntegerType>(cast<StructType>(Ty)->getStructElementType (0 ))) {
257+ OS << " The first element in the Indirection Table must be an integer; " ;
258+ cast<StructType>(Ty)->getStructElementType (0 )->print (OS);
259+ OS << " is incorrect.\n " ;
260+ } else if (!isa<PointerType>(cast<StructType>(Ty)->getStructElementType (1 ))) {
261+ OS << " The second element in the Indirection Table must be a pointer; " ;
262+ cast<StructType>(Ty)->getStructElementType (1 )->print (OS);
263+ OS << " is incorrect.\n " ;
264+ } else if (!isa<StructType>(cast<StructType>(Ty)->getStructElementType (2 ))) {
265+ OS << " The third element in the Indirection Table must be a struct type; " ;
266+ cast<StructType>(Ty)->getStructElementType (2 )->print (OS);
267+ OS << " is incorrect.\n " ;
268+ } else {
269+ Valid = true ;
270+ }
271+
272+ if (!Valid)
273+ IndirectionTable->getContext ().diagnose (DiagnosticInfoGeneric (W, DS_Error));
274+
275+ return Valid;
276+ }
277+
278+ static void indirectGlobals (GlobalVariable *IndirectionTable,
279+ SmallVector<GlobalVariable *> ToIndirect) {
280+ // We replace globals with an indirected access via a pointer that will get
281+ // set by the HIPSTDPAR runtime, using their accessible, program-wide unique
282+ // address as set by the host linker-loader.
283+ SmallVector<Constant *> SymbolIndirections;
284+ for (auto &&G : ToIndirect) {
285+ SmallVector<std::reference_wrapper<Use>> Uses = collectIndirectableUses (G);
286+
287+ if (Uses.empty ())
288+ continue ;
289+
290+ Constant *IndirectedGlobal =
291+ appendIndirectedGlobal (IndirectionTable, SymbolIndirections, G);
292+
293+ for_each (Uses,
294+ [=](auto &&U) { replaceWithIndirectUse (U, G, IndirectedGlobal); });
295+
296+ eraseFromModule (*G);
297+ }
298+
299+ if (SymbolIndirections.empty ())
300+ return ;
301+
302+ fillIndirectionTable (IndirectionTable, std::move (SymbolIndirections));
303+ }
304+
117305static inline void maybeHandleGlobals (Module &M) {
118306 unsigned GlobAS = M.getDataLayout ().getDefaultGlobalsAddressSpace ();
119- for (auto &&G : M.globals ()) { // TODO: should we handle these in the FE?
307+
308+ SmallVector<GlobalVariable *> ToIndirect;
309+ for (auto &&G : M.globals ()) {
120310 if (!checkIfSupported (G))
121311 return clearModule (M);
122-
123- if (G.isThreadLocal ())
124- continue ;
125- if (G.isConstant ())
126- continue ;
127312 if (G.getAddressSpace () != GlobAS)
128313 continue ;
129- if (G.getLinkage () != GlobalVariable::ExternalLinkage )
314+ if (G.isConstant () && G. hasInitializer () && G. hasAtLeastLocalUnnamedAddr () )
130315 continue ;
131316
132- G.setLinkage (GlobalVariable::ExternalWeakLinkage);
133- G.setInitializer (nullptr );
134- G.setExternallyInitialized (true );
317+ ToIndirect.push_back (&G);
318+ }
319+
320+ if (ToIndirect.empty ())
321+ return ;
322+
323+ if (auto *IT = M.getNamedGlobal (" __hipstdpar_symbol_indirection_table" )) {
324+ if (!isValidIndirectionTable (IT))
325+ return clearModule (M);
326+ return indirectGlobals (IT, std::move (ToIndirect));
327+ } else {
328+ for (auto &&G : ToIndirect) {
329+ // We will internalise these, so we provide a poison initialiser.
330+ if (!G->hasInitializer ())
331+ G->setInitializer (PoisonValue::get (G->getValueType ()));
332+ }
135333 }
136334}
137335
0 commit comments