44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
7- // ===----------------------------------------------------------------------===//
7+
88//
99// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
1010// with an integer.
2525#include " llvm/IR/Constants.h"
2626#include " llvm/IR/DerivedTypes.h"
2727#include " llvm/IR/Function.h"
28- #include " llvm/IR/InstIterator.h"
2928#include " llvm/IR/Instructions.h"
3029#include " llvm/IR/Intrinsics.h"
3130#include " llvm/IR/IntrinsicsNVPTX.h"
3938#include " llvm/Transforms/Scalar.h"
4039#include " llvm/Transforms/Utils/BasicBlockUtils.h"
4140#include " llvm/Transforms/Utils/Local.h"
41+ #include " llvm/Transforms/Utils/StripGCRelocates.h"
4242#include < algorithm>
4343#define NVVM_REFLECT_FUNCTION " __nvvm_reflect"
4444#define NVVM_REFLECT_OCL_FUNCTION " __nvvm_reflect_ocl"
4545
4646using namespace llvm ;
4747
48- #define DEBUG_TYPE " nvptx -reflect"
48+ #define DEBUG_TYPE " nvvm -reflect"
4949
5050namespace {
51- class NVVMReflect : public FunctionPass {
51+ class NVVMReflect : public ModulePass {
52+ private:
53+ StringMap<int > VarMap;
54+ // / Process a reflect function by finding all its uses and replacing them with
55+ // / appropriate constant values. For __CUDA_FTZ, uses the module flag value.
56+ // / For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
57+ bool handleReflectFunction (Function *F);
58+ void setVarMap (Module &M);
59+
5260public:
5361 static char ID;
54- unsigned int SmVersion;
5562 NVVMReflect () : NVVMReflect(0 ) {}
56- explicit NVVMReflect (unsigned int Sm) : FunctionPass(ID), SmVersion(Sm) {}
57-
58- bool runOnFunction (Function &) override ;
63+ // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
64+ // metadata.
65+ explicit NVVMReflect (unsigned int Sm) : ModulePass(ID) {
66+ VarMap[" __CUDA_ARCH" ] = Sm * 10 ;
67+ initializeNVVMReflectPass (*PassRegistry::getPassRegistry ());
68+ }
69+ // This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
70+ explicit NVVMReflect (const StringMap<int > &Mapping) : ModulePass(ID), VarMap(Mapping) {
71+ initializeNVVMReflectPass (*PassRegistry::getPassRegistry ());
72+ }
73+ bool runOnModule (Module &M) override ;
5974};
6075} // namespace
6176
62- FunctionPass *llvm::createNVVMReflectPass (unsigned int SmVersion) {
77+ ModulePass *llvm::createNVVMReflectPass (unsigned int SmVersion) {
6378 return new NVVMReflect (SmVersion);
6479}
6580
@@ -72,27 +87,51 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
7287 " Replace occurrences of __nvvm_reflect() calls with 0/1" , false ,
7388 false )
7489
75- static bool runNVVMReflect(Function &F, unsigned SmVersion) {
76- if (!NVVMReflectEnabled)
77- return false ;
90+ static cl::list<std::string>
91+ ReflectList(" nvvm-reflect-list" , cl::value_desc(" name=<int>" ), cl::Hidden,
92+ cl::desc(" A list of string=num assignments" ),
93+ cl::ValueRequired);
7894
79- if (F.getName () == NVVM_REFLECT_FUNCTION ||
80- F.getName () == NVVM_REFLECT_OCL_FUNCTION) {
81- assert (F.isDeclaration () && " _reflect function should not have a body" );
82- assert (F.getReturnType ()->isIntegerTy () &&
83- " _reflect's return type should be integer" );
84- return false ;
95+ // / The command line can look as follows :
96+ // / -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
97+ // / The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
98+ // / ReflectList vector. First, each of ReflectList[i] is 'split'
99+ // / using "," as the delimiter. Then each of this part is split
100+ // / using "=" as the delimiter.
101+ void NVVMReflect::setVarMap (Module &M) {
102+ if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
103+ M.getModuleFlag (" nvvm-reflect-ftz" )))
104+ VarMap[" __CUDA_FTZ" ] = Flag->getSExtValue ();
105+
106+ for (unsigned I = 0 , E = ReflectList.size (); I != E; ++I) {
107+ LLVM_DEBUG (dbgs () << " Option : " << ReflectList[I] << " \n " );
108+ SmallVector<StringRef, 4 > NameValList;
109+ StringRef (ReflectList[I]).split (NameValList, " ," );
110+ for (unsigned J = 0 , EJ = NameValList.size (); J != EJ; ++J) {
111+ SmallVector<StringRef, 2 > NameValPair;
112+ NameValList[J].split (NameValPair, " =" );
113+ assert (NameValPair.size () == 2 && " name=val expected" );
114+ StringRef ValStr = NameValPair[1 ].trim ();
115+ int Val;
116+ if (ValStr.getAsInteger (10 , Val))
117+ report_fatal_error (" integer value expected" );
118+ VarMap[NameValPair[0 ]] = Val;
119+ }
85120 }
121+ }
122+
123+ bool NVVMReflect::handleReflectFunction (Function *F) {
124+ // Validate _reflect function
125+ assert (F->isDeclaration () && " _reflect function should not have a body" );
126+ assert (F->getReturnType ()->isIntegerTy () &&
127+ " _reflect's return type should be integer" );
86128
87129 SmallVector<Instruction *, 4 > ToRemove;
88130 SmallVector<Instruction *, 4 > ToSimplify;
89131
90- // Go through the calls in this function. Each call to __nvvm_reflect or
91- // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
92- // First validate that. If the c-string corresponding to the ConstantArray can
93- // be found successfully, see if it can be found in VarMap. If so, replace the
94- // uses of CallInst with the value found in VarMap. If not, replace the use
95- // with value 0.
132+ // Go through the uses of the reflect function. Each use should be a CallInst
133+ // with a ConstantArray argument. Replace the uses with the appropriate
134+ // constant values.
96135
97136 // The IR for __nvvm_reflect calls differs between CUDA versions.
98137 //
@@ -113,15 +152,10 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
113152 //
114153 // In this case, we get a Constant with a GlobalVariable operand and we need
115154 // to dig deeper to find its initializer with the string we'll use for lookup.
116- for (Instruction &I : instructions (F)) {
117- CallInst *Call = dyn_cast<CallInst>(&I);
118- if (!Call)
119- continue ;
120- Function *Callee = Call->getCalledFunction ();
121- if (!Callee || (Callee->getName () != NVVM_REFLECT_FUNCTION &&
122- Callee->getName () != NVVM_REFLECT_OCL_FUNCTION &&
123- Callee->getIntrinsicID () != Intrinsic::nvvm_reflect))
124- continue ;
155+
156+ for (User *U : F->users ()) {
157+ assert (isa<CallInst>(U) && " Only a call instruction can use _reflect" );
158+ CallInst *Call = cast<CallInst>(U);
125159
126160 // FIXME: Improve error handling here and elsewhere in this pass.
127161 assert (Call->getNumOperands () == 2 &&
@@ -156,20 +190,15 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
156190 " Format of _reflect function not recognized" );
157191
158192 StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString ();
193+ // Remove the null terminator from the string
159194 ReflectArg = ReflectArg.substr (0 , ReflectArg.size () - 1 );
160195 LLVM_DEBUG (dbgs () << " Arg of _reflect : " << ReflectArg << " \n " );
161196
162197 int ReflectVal = 0 ; // The default value is 0
163- if (ReflectArg == " __CUDA_FTZ" ) {
164- // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag. Our
165- // choice here must be kept in sync with AutoUpgrade, which uses the same
166- // technique to detect whether ftz is enabled.
167- if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
168- F.getParent ()->getModuleFlag (" nvvm-reflect-ftz" )))
169- ReflectVal = Flag->getSExtValue ();
170- } else if (ReflectArg == " __CUDA_ARCH" ) {
171- ReflectVal = SmVersion * 10 ;
198+ if (VarMap.contains (ReflectArg)) {
199+ ReflectVal = VarMap[ReflectArg];
172200 }
201+ LLVM_DEBUG (dbgs () << " ReflectVal: " << ReflectVal << " \n " );
173202
174203 // If the immediate user is a simple comparison we want to simplify it.
175204 for (User *U : Call->users ())
@@ -185,7 +214,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
185214 // until we find a terminator that we can then remove.
186215 while (!ToSimplify.empty ()) {
187216 Instruction *I = ToSimplify.pop_back_val ();
188- if (Constant *C = ConstantFoldInstruction (I, F. getDataLayout ())) {
217+ if (Constant *C = ConstantFoldInstruction (I, F-> getDataLayout ())) {
189218 for (User *U : I->users ())
190219 if (Instruction *I = dyn_cast<Instruction>(U))
191220 ToSimplify.push_back (I);
@@ -202,23 +231,45 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
202231 // Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
203232 // array. Filter out the duplicates before starting to erase from parent.
204233 std::sort (ToRemove.begin (), ToRemove.end ());
205- auto NewLastIter = llvm::unique (ToRemove);
234+ auto * NewLastIter = llvm::unique (ToRemove);
206235 ToRemove.erase (NewLastIter, ToRemove.end ());
207236
208237 for (Instruction *I : ToRemove)
209238 I->eraseFromParent ();
210239
240+ // Remove the __nvvm_reflect function from the module
241+ F->eraseFromParent ();
211242 return ToRemove.size () > 0 ;
212243}
213244
214- bool NVVMReflect::runOnFunction (Function &F) {
215- return runNVVMReflect (F, SmVersion);
216- }
245+ bool NVVMReflect::runOnModule (Module &M) {
246+ if (!NVVMReflectEnabled)
247+ return false ;
248+
249+ setVarMap (M);
217250
218- NVVMReflectPass::NVVMReflectPass () : NVVMReflectPass(0 ) {}
251+ bool Changed = false ;
252+ // Names of reflect function to find and replace
253+ SmallVector<std::string, 3 > ReflectNames = {
254+ NVVM_REFLECT_FUNCTION, NVVM_REFLECT_OCL_FUNCTION,
255+ Intrinsic::getName (Intrinsic::nvvm_reflect).str ()};
256+
257+ // Process all reflect functions
258+ for (const std::string &Name : ReflectNames) {
259+ Function *ReflectFunction = M.getFunction (Name);
260+ if (ReflectFunction) {
261+ Changed |= handleReflectFunction (ReflectFunction);
262+ }
263+ }
264+
265+ return Changed;
266+ }
219267
220- PreservedAnalyses NVVMReflectPass::run (Function &F,
221- FunctionAnalysisManager &AM) {
222- return runNVVMReflect (F, SmVersion) ? PreservedAnalyses::none ()
223- : PreservedAnalyses::all ();
268+ // Implementations for the pass that works with the new pass manager.
269+ NVVMReflectPass::NVVMReflectPass (unsigned SmVersion) {
270+ VarMap[" __CUDA_ARCH" ] = SmVersion * 10 ;
224271}
272+ PreservedAnalyses NVVMReflectPass::run (Module &M, ModuleAnalysisManager &AM) {
273+ return NVVMReflect (VarMap).runOnModule (M) ? PreservedAnalyses::none ()
274+ : PreservedAnalyses::all ();
275+ }
0 commit comments