4141#include " llvm/ADT/StringExtras.h"
4242#define NVVM_REFLECT_FUNCTION " __nvvm_reflect"
4343#define NVVM_REFLECT_OCL_FUNCTION " __nvvm_reflect_ocl"
44+ // Argument of reflect call to retrieve arch number
45+ #define CUDA_ARCH_NAME " __CUDA_ARCH"
46+ // Argument of reflect call to retrieve ftz mode
47+ #define CUDA_FTZ_NAME " __CUDA_FTZ"
48+ // Name of module metadata where ftz mode is stored
49+ #define CUDA_FTZ_MODULE_NAME " nvvm-reflect-ftz"
4450
4551using namespace llvm ;
4652
4753#define DEBUG_TYPE " nvvm-reflect"
4854
// Forward declaration of the registration hook that the NVVMReflectLegacyPass
// constructor calls with the global PassRegistry.
// NOTE(review): presumably the definition comes from the INITIALIZE_PASS
// expansion later in this file -- confirm.
55+ namespace llvm {
56+ void initializeNVVMReflectLegacyPassPass (PassRegistry &);
57+ }
58+
4959namespace {
50- class NVVMReflect : public ModulePass {
60+ class NVVMReflect {
5161private:
52- StringMap<int > VarMap;
53- void handleReflectFunction (Function *F);
54- void setVarMap (Module &M);
62+ // Map from reflect function call arguments to the value to replace the call with.
63+ // Should include __CUDA_FTZ and __CUDA_ARCH values.
64+ StringMap<int > ReflectMap;
65+ bool handleReflectFunction (Module &M, StringRef ReflectName);
66+ void populateReflectMap (Module &M);
5567 void foldReflectCall (CallInst *Call, Constant *NewValue);
5668public:
57- static char ID;
58- NVVMReflect () : NVVMReflect(0 ) {}
5969 // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
6070 // metadata.
61- explicit NVVMReflect (unsigned SmVersion) : ModulePass(ID), VarMap({{" __CUDA_ARCH" , SmVersion * 10 }}) {
62- initializeNVVMReflectPass (*PassRegistry::getPassRegistry ());
63- }
64- // This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
65- explicit NVVMReflect (const StringMap<int > &Mapping) : ModulePass(ID), VarMap(Mapping) {
66- initializeNVVMReflectPass (*PassRegistry::getPassRegistry ());
71+ explicit NVVMReflect (unsigned SmVersion) : ReflectMap({{CUDA_ARCH_NAME, SmVersion * 10 }}) {}
72+ bool runOnModule (Module &M);
73+ };
74+ } // namespace
75+
76+ class NVVMReflectLegacyPass : public ModulePass {
77+ private:
78+ NVVMReflect Impl;
79+ public:
80+ static char ID;
81+ NVVMReflectLegacyPass (unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {
82+ initializeNVVMReflectLegacyPassPass (*PassRegistry::getPassRegistry ());
6783 }
6884 bool runOnModule (Module &M) override ;
6985};
70- } // namespace
7186
// Factory for the legacy pass manager: logs the requested SM version under
// LLVM_DEBUG and returns a heap-allocated NVVMReflectLegacyPass built from it.
// NOTE(review): the caller presumably takes ownership of the returned pass -- confirm.
7287ModulePass *llvm::createNVVMReflectPass (unsigned int SmVersion) {
73- return new NVVMReflect (SmVersion);
88+ LLVM_DEBUG (dbgs () << " Creating NVVMReflectPass with SM version " << SmVersion << " \n " );
89+ return new NVVMReflectLegacyPass (SmVersion);
7490}
7591
// Hidden boolean command-line switch, initialised to true.
// NVVMReflect::runOnModule returns false immediately when it is off.
7692static cl::opt<bool >
7793 NVVMReflectEnabled (" nvvm-reflect-enable" , cl::init(true ), cl::Hidden,
7894 cl::desc(" NVVM reflection, enabled by default" ));
7995
// Static pass identifier for the legacy pass plus its INITIALIZE_PASS
// registration; the pass argument string is kept from the old NVVMReflect class.
80- char NVVMReflect ::ID = 0 ;
81- INITIALIZE_PASS (NVVMReflect , " nvvm-reflect" ,
96+ char NVVMReflectLegacyPass ::ID = 0 ;
97+ INITIALIZE_PASS (NVVMReflectLegacyPass , " nvvm-reflect" ,
8298 " Replace occurrences of __nvvm_reflect() calls with 0/1" , false ,
8399 false )
84100
@@ -92,10 +108,10 @@ ReflectList("nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
92108
93109// Populate ReflectMap with, first, the value of __CUDA_FTZ from module metadata, and then
94110// the key/value pairs from the command line.
95- void NVVMReflect::setVarMap (Module &M) {
111+ void NVVMReflect::populateReflectMap (Module &M) {
96112 if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
97- M.getModuleFlag (" nvvm-reflect-ftz " )))
98- VarMap[ " __CUDA_FTZ " ] = Flag->getSExtValue ();
113+ M.getModuleFlag (CUDA_FTZ_MODULE_NAME )))
114+ ReflectMap[CUDA_FTZ_NAME ] = Flag->getSExtValue ();
99115
100116 for (StringRef Option : ReflectList) {
101117 LLVM_DEBUG (dbgs () << " ReflectOption : " << Option << " \n " );
@@ -107,94 +123,52 @@ void NVVMReflect::setVarMap(Module &M) {
107123 int ValInt;
108124 if (!to_integer (Val.trim (), ValInt, 10 ))
109125 report_fatal_error (" integer value expected in nvvm-reflect-list option '" + Option + " '" );
110- VarMap [Name] = ValInt;
126+ ReflectMap [Name] = ValInt;
111127 }
112128}
113129
114- // / Process a reflect function by finding all its uses and replacing them with
130+ // / Process a reflect function by finding all its calls and replacing them with
115131// / appropriate constant values. For __CUDA_FTZ, uses the module flag value.
116132// / For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
117- void NVVMReflect::handleReflectFunction (Function *F) {
118- // Validate _reflect function
133+ bool NVVMReflect::handleReflectFunction (Module &M, StringRef ReflectName) {
134+ Function *F = M.getFunction (ReflectName);
135+ if (!F)
136+ return false ;
119137 assert (F->isDeclaration () && " _reflect function should not have a body" );
120138 assert (F->getReturnType ()->isIntegerTy () && " _reflect's return type should be integer" );
121139
122-
123- // Go through the uses of the reflect function. Each use should be a CallInst
124- // with a ConstantArray argument. Replace the uses with the appropriate constant values.
125-
126- // The IR for __nvvm_reflect calls differs between CUDA versions.
127- //
128- // CUDA 6.5 and earlier uses this sequence:
129- // %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8
130- // (i8 addrspace(4)* getelementptr inbounds
131- // ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0))
132- // %reflect = tail call i32 @__nvvm_reflect(i8* %ptr)
133- //
134- // The value returned by Sym->getOperand(0) is a Constant with a
135- // ConstantDataSequential operand which can be converted to string and used
136- // for lookup.
137- //
138- // CUDA 7.0 does it slightly differently:
139- // %reflect = call i32 @__nvvm_reflect(i8* addrspacecast
140- // (i8 addrspace(1)* getelementptr inbounds
141- // ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
142- //
143- // In this case, we get a Constant with a GlobalVariable operand and we need
144- // to dig deeper to find its initializer with the string we'll use for lookup.
145-
140+ bool Changed = F->getNumUses () > 0 ;
146141 for (User *U : make_early_inc_range (F->users ())) {
142+ // Reflect function calls look like:
143+ // @arch = private unnamed_addr addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
144+ // call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @arch to ptr))
145+ // We need to extract the string argument from the call (i.e. "__CUDA_ARCH")
147146 if (!isa<CallInst>(U))
148147 report_fatal_error (" __nvvm_reflect can only be used in a call instruction" );
149148 CallInst *Call = cast<CallInst>(U);
150-
151149 if (Call->getNumOperands () != 2 )
152150 report_fatal_error (" __nvvm_reflect requires exactly one argument" );
153151
154- // In cuda 6.5 and earlier, we will have an extra constant-to-generic
155- // conversion of the string.
156- const Value *Str = Call->getArgOperand (0 );
157- if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
158- // Verify this is the constant-to-generic intrinsic
159- Function *Callee = ConvCall->getCalledFunction ();
160- if (!Callee || !Callee->isIntrinsic () ||
161- !Callee->getName ().starts_with (" llvm.nvvm.ptr.constant.to.gen" ))
162- report_fatal_error (" Expected llvm.nvvm.ptr.constant.to.gen intrinsic" );
163- if (ConvCall->getNumOperands () != 2 )
164- report_fatal_error (" Expected one argument for ptr conversion" );
165- Str = ConvCall->getArgOperand (0 );
166- }
167- // Pre opaque pointers we have a constant expression wrapping the constant
168- Str = Str->stripPointerCasts ();
169- if (!isa<Constant>(Str))
152+ const Value *GlobalStr = Call->getArgOperand (0 )->stripPointerCasts ();
153+ if (!isa<Constant>(GlobalStr))
170154 report_fatal_error (" __nvvm_reflect argument must be a constant string" );
171155
172- const Value *Operand = cast<Constant>(Str)->getOperand (0 );
173- if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
174- // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
175- // initializer.
176- if (!GV->hasInitializer ())
177- report_fatal_error (" __nvvm_reflect string must have an initializer" );
178- const Constant *Initializer = GV->getInitializer ();
179- Operand = Initializer;
180- }
181-
182- if (!isa<ConstantDataSequential>(Operand))
156+ const Value *ConstantStr = cast<Constant>(GlobalStr)->getOperand (0 );
157+ if (!isa<ConstantDataSequential>(ConstantStr))
183158 report_fatal_error (" __nvvm_reflect argument must be a string constant" );
184- if (!cast<ConstantDataSequential>(Operand )->isCString ())
159+ if (!cast<ConstantDataSequential>(ConstantStr )->isCString ())
185160 report_fatal_error (" __nvvm_reflect argument must be a null-terminated string" );
186161
187- StringRef ReflectArg = cast<ConstantDataSequential>(Operand )->getAsString ();
162+ StringRef ReflectArg = cast<ConstantDataSequential>(ConstantStr )->getAsString ();
188163 // Remove the null terminator from the string
189164 ReflectArg = ReflectArg.substr (0 , ReflectArg.size () - 1 );
190-
191165 if (ReflectArg.empty ())
192166 report_fatal_error (" __nvvm_reflect argument cannot be empty" );
193-
167+ // Now that we have extracted the string argument, we can look it up in the ReflectMap
194168 int ReflectVal = 0 ; // The default value is 0
195- if (VarMap .contains (ReflectArg)) {
196- ReflectVal = VarMap [ReflectArg];
197- }
169+ if (ReflectMap .contains (ReflectArg))
170+ ReflectVal = ReflectMap [ReflectArg];
171+
198172 LLVM_DEBUG (dbgs () << " Replacing call of reflect function " << F->getName () << " (" << ReflectArg << " ) with value " << ReflectVal << " \n " );
199173 Constant *NewValue = ConstantInt::get (Call->getType (), ReflectVal);
200174 foldReflectCall (Call, NewValue);
@@ -203,29 +177,26 @@ void NVVMReflect::handleReflectFunction(Function *F) {
203177
204178 // Remove the __nvvm_reflect function from the module
205179 F->eraseFromParent ();
180+ return Changed;
206181}
207182
208183void NVVMReflect::foldReflectCall (CallInst *Call, Constant *NewValue) {
209- // Initialize worklist with all users of the call
210184 SmallVector<Instruction*, 8 > Worklist;
211- for (User *U : Call->users ()) {
212- if (Instruction *I = dyn_cast<Instruction>(U)) {
213- Worklist.push_back (I);
185+ // Replace an instruction with a constant and add all users of the instruction to the worklist
186+ auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
187+ for (User *U : I->users ()) {
188+ if (Instruction *UI = dyn_cast<Instruction>(U))
189+ Worklist.push_back (UI);
214190 }
215- }
191+ I->replaceAllUsesWith (C);
192+ };
216193
217- Call-> replaceAllUsesWith ( NewValue);
194+ ReplaceInstructionWithConst (Call, NewValue);
218195
219196 while (!Worklist.empty ()) {
220197 Instruction *I = Worklist.pop_back_val ();
221198 if (Constant *C = ConstantFoldInstruction (I, Call->getModule ()->getDataLayout ())) {
222- // Add all users of this instruction to the worklist, replace it with the constant
223- // then delete it if it's dead
224- for (User *U : I->users ()) {
225- if (Instruction *UI = dyn_cast<Instruction>(U))
226- Worklist.push_back (UI);
227- }
228- I->replaceAllUsesWith (C);
199+ ReplaceInstructionWithConst (I, C);
229200 if (isInstructionTriviallyDead (I))
230201 I->eraseFromParent ();
231202 } else if (I->isTerminator ()) {
@@ -237,37 +208,20 @@ void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
237208bool NVVMReflect::runOnModule (Module &M) {
238209 if (!NVVMReflectEnabled)
239210 return false ;
240-
241- setVarMap (M);
242-
243- bool Changed = false ;
244- // Names of reflect function to find and replace
245- SmallVector<StringRef, 5 > ReflectNames = {
246- NVVM_REFLECT_FUNCTION,
247- NVVM_REFLECT_OCL_FUNCTION,
248- Intrinsic::getName (Intrinsic::nvvm_reflect),
249- };
250-
251- // Process all reflect functions
252- for (StringRef Name : ReflectNames) {
253- if (Function *ReflectFunction = M.getFunction (Name)) {
254- // If the reflect functition is called, we need to replace the call
255- // with the appropriate constant, modifying the IR.
256- Changed |= ReflectFunction->getNumUses () > 0 ;
257- handleReflectFunction (ReflectFunction);
258- }
259- }
260-
211+ populateReflectMap (M);
212+ bool Changed = true ;
213+ handleReflectFunction (M, NVVM_REFLECT_FUNCTION);
214+ handleReflectFunction (M, NVVM_REFLECT_OCL_FUNCTION);
215+ handleReflectFunction (M, Intrinsic::getName (Intrinsic::nvvm_reflect));
261216 return Changed;
262217}
263218
264- // Implementations for the pass that works with the new pass manager.
265- NVVMReflectPass::NVVMReflectPass (unsigned SmVersion) {
266- VarMap[" __CUDA_ARCH" ] = SmVersion * 10 ;
219+ bool NVVMReflectLegacyPass::runOnModule (Module &M) {
220+ return Impl.runOnModule (M);
267221}
268222
269223PreservedAnalyses NVVMReflectPass::run (Module &M,
270224 ModuleAnalysisManager &AM) {
271- return NVVMReflect (VarMap ).runOnModule (M) ? PreservedAnalyses::none ()
225+ return NVVMReflect (SmVersion ).runOnModule (M) ? PreservedAnalyses::none ()
272226 : PreservedAnalyses::all ();
273227}
0 commit comments