44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
7+ // NVIDIA_COPYRIGHT_BEGIN
8+ //
9+ // Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved.
10+ //
11+ // NVIDIA_COPYRIGHT_END
12+ //
713// ===----------------------------------------------------------------------===//
814//
915// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
3844#include " llvm/Transforms/Scalar.h"
3945#include " llvm/Transforms/Utils/BasicBlockUtils.h"
4046#include " llvm/Transforms/Utils/Local.h"
41- #include " llvm/Transforms/Utils/StripGCRelocates.h"
42- #include < algorithm>
47+ #include " llvm/ADT/StringExtras.h"
4348#define NVVM_REFLECT_FUNCTION " __nvvm_reflect"
4449#define NVVM_REFLECT_OCL_FUNCTION " __nvvm_reflect_ocl"
4550
@@ -54,16 +59,15 @@ class NVVMReflect : public ModulePass {
5459 // / Process a reflect function by finding all its uses and replacing them with
5560 // / appropriate constant values. For __CUDA_FTZ, uses the module flag value.
5661 // / For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
57- bool handleReflectFunction (Function *F);
62+ void handleReflectFunction (Function *F);
5863 void setVarMap (Module &M);
59-
64+ void foldReflectCall (CallInst *Call, Constant *NewValue);
6065public:
6166 static char ID;
6267 NVVMReflect () : NVVMReflect(0 ) {}
6368 // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
6469 // metadata.
65- explicit NVVMReflect (unsigned int Sm) : ModulePass(ID) {
66- VarMap[" __CUDA_ARCH" ] = Sm * 10 ;
70+ explicit NVVMReflect (unsigned SmVersion) : ModulePass(ID), VarMap({{" __CUDA_ARCH" , SmVersion * 10 }}) {
6771 initializeNVVMReflectPass (*PassRegistry::getPassRegistry ());
6872 }
6973 // This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
@@ -87,51 +91,58 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
8791 " Replace occurrences of __nvvm_reflect() calls with 0/1" , false ,
8892 false )
8993
94+ // Allow users to specify additional key/value pairs to reflect. These key/value pairs
95+ // are the last to be added to the VarMap, and therefore will take precedence over initial
96+ // values (i.e. __CUDA_FTZ from module medadata and __CUDA_ARCH from SmVersion).
9097static cl::list<std::string>
91- ReflectList(" nvvm-reflect-list" , cl::value_desc(" name=<int>" ), cl::Hidden,
92- cl::desc(" A list of string=num assignments" ),
93- cl::ValueRequired);
94-
95- // / The command line can look as follows :
96- // / -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
97- // / The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
98- // / ReflectList vector. First, each of ReflectList[i] is 'split'
99- // / using "," as the delimiter. Then each of this part is split
100- // / using "=" as the delimiter.
98+ ReflectList(" nvvm-reflect-add" , cl::value_desc(" name=<int>" ), cl::Hidden,
99+ cl::desc(" list of comma-separated key=value pairs" ),
100+ cl::ValueRequired);
101+
102+ // Set the VarMap with, first, the value of __CUDA_FTZ from module metadata, and then
103+ // the key/value pairs from the command line.
101104void NVVMReflect::setVarMap (Module &M) {
105+ LLVM_DEBUG (dbgs () << " Reflect list values:\n " );
106+ for (StringRef Option : ReflectList) {
107+ LLVM_DEBUG (dbgs () << " " << Option << " \n " );
108+ }
102109 if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
103- M.getModuleFlag (" nvvm-reflect-ftz" )))
110+ M.getModuleFlag (" nvvm-reflect-ftz" )))
104111 VarMap[" __CUDA_FTZ" ] = Flag->getSExtValue ();
105112
106- for (unsigned I = 0 , E = ReflectList.size (); I != E; ++I) {
107- LLVM_DEBUG (dbgs () << " Option : " << ReflectList[I] << " \n " );
108- SmallVector<StringRef, 4 > NameValList;
109- StringRef (ReflectList[I]).split (NameValList, " ," );
110- for (unsigned J = 0 , EJ = NameValList.size (); J != EJ; ++J) {
111- SmallVector<StringRef, 2 > NameValPair;
112- NameValList[J].split (NameValPair, " =" );
113- assert (NameValPair.size () == 2 && " name=val expected" );
114- StringRef ValStr = NameValPair[1 ].trim ();
113+ // / The command line can look as follows :
114+ // / -nvvm-reflect-add a=1,b=2 -nvvm-reflect-add c=3,d=0 -nvvm-reflect-add e=2
115+ // / The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
116+ // / ReflectList vector. First, each of ReflectList[i] is 'split'
117+ // / using "," as the delimiter. Then each of this part is split
118+ // / using "=" as the delimiter.
119+ for (StringRef Option : ReflectList) {
120+ LLVM_DEBUG (dbgs () << " ReflectOption : " << Option << " \n " );
121+ while (!Option.empty ()) {
122+ std::pair<StringRef, StringRef> Split = Option.split (' ,' );
123+ StringRef NameVal = Split.first ;
124+ Option = Split.second ;
125+
126+ auto NameValPair = NameVal.split (' =' );
127+ assert (!NameValPair.first .empty () && !NameValPair.second .empty () &&
128+ " name=val expected" );
129+
115130 int Val;
116- if (ValStr. getAsInteger ( 10 , Val))
131+ if (! to_integer (NameValPair. second . trim () , Val, 10 ))
117132 report_fatal_error (" integer value expected" );
118- VarMap[NameValPair[ 0 ] ] = Val;
133+ VarMap[NameValPair. first ] = Val;
119134 }
120135 }
121136}
122137
123- bool NVVMReflect::handleReflectFunction (Function *F) {
138+ void NVVMReflect::handleReflectFunction (Function *F) {
124139 // Validate _reflect function
125140 assert (F->isDeclaration () && " _reflect function should not have a body" );
126- assert (F->getReturnType ()->isIntegerTy () &&
127- " _reflect's return type should be integer" );
141+ assert (F->getReturnType ()->isIntegerTy () && " _reflect's return type should be integer" );
128142
129- SmallVector<Instruction *, 4 > ToRemove;
130- SmallVector<Instruction *, 4 > ToSimplify;
131143
132144 // Go through the uses of the reflect function. Each use should be a CallInst
133- // with a ConstantArray argument. Replace the uses with the appropriate
134- // constant values.
145+ // with a ConstantArray argument. Replace the uses with the appropriate constant values.
135146
136147 // The IR for __nvvm_reflect calls differs between CUDA versions.
137148 //
@@ -153,7 +164,7 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
153164 // In this case, we get a Constant with a GlobalVariable operand and we need
154165 // to dig deeper to find its initializer with the string we'll use for lookup.
155166
156- for (User *U : F->users ()) {
167+ for (User *U : make_early_inc_range ( F->users () )) {
157168 assert (isa<CallInst>(U) && " Only a call instruction can use _reflect" );
158169 CallInst *Call = cast<CallInst>(U);
159170
@@ -165,21 +176,23 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
165176 // conversion of the string.
166177 const Value *Str = Call->getArgOperand (0 );
167178 if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
168- // FIXME: Add assertions about ConvCall.
179+ // Verify this is the constant-to-generic intrinsic
180+ Function *Callee = ConvCall->getCalledFunction ();
181+ assert (Callee && Callee->isIntrinsic () &&
182+ Callee->getName ().starts_with (" llvm.nvvm.ptr.constant.to.gen" ) &&
183+ " Expected llvm.nvvm.ptr.constant.to.gen intrinsic" );
184+ assert (ConvCall->getNumOperands () == 2 && " Expected one argument for ptr conversion" );
169185 Str = ConvCall->getArgOperand (0 );
170186 }
171187 // Pre opaque pointers we have a constant expression wrapping the constant
172- // string.
173188 Str = Str->stripPointerCasts ();
174- assert (isa<Constant>(Str) &&
175- " Format of __nvvm_reflect function not recognized" );
189+ assert (isa<Constant>(Str) && " Format of __nvvm_reflect function not recognized" );
176190
177191 const Value *Operand = cast<Constant>(Str)->getOperand (0 );
178192 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
179193 // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
180194 // initializer.
181- assert (GV->hasInitializer () &&
182- " Format of _reflect function not recognized" );
195+ assert (GV->hasInitializer () && " Format of _reflect function not recognized" );
183196 const Constant *Initializer = GV->getInitializer ();
184197 Operand = Initializer;
185198 }
@@ -192,54 +205,48 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
192205 StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString ();
193206 // Remove the null terminator from the string
194207 ReflectArg = ReflectArg.substr (0 , ReflectArg.size () - 1 );
195- LLVM_DEBUG (dbgs () << " Arg of _reflect : " << ReflectArg << " \n " );
196208
197209 int ReflectVal = 0 ; // The default value is 0
198210 if (VarMap.contains (ReflectArg)) {
199211 ReflectVal = VarMap[ReflectArg];
200212 }
201- LLVM_DEBUG (dbgs () << " ReflectVal: " << ReflectVal << " \n " );
213+ LLVM_DEBUG (dbgs () << " Replacing call of reflect function " << F->getName () << " (" << ReflectArg << " ) with value " << ReflectVal << " \n " );
214+ Constant *NewValue = ConstantInt::get (Call->getType (), ReflectVal);
215+ foldReflectCall (Call, NewValue);
216+ Call->eraseFromParent ();
217+ }
202218
203- // If the immediate user is a simple comparison we want to simplify it.
204- for (User *U : Call->users ())
205- if (Instruction *I = dyn_cast<Instruction>(U))
206- ToSimplify.push_back (I);
219+ // Remove the __nvvm_reflect function from the module
220+ F->eraseFromParent ();
221+ }
207222
208- Call->replaceAllUsesWith (ConstantInt::get (Call->getType (), ReflectVal));
209- ToRemove.push_back (Call);
223+ void NVVMReflect::foldReflectCall (CallInst *Call, Constant *NewValue) {
224+ // Initialize worklist with all users of the call
225+ SmallVector<Instruction*, 8 > Worklist;
226+ for (User *U : Call->users ()) {
227+ if (Instruction *I = dyn_cast<Instruction>(U)) {
228+ Worklist.push_back (I);
229+ }
210230 }
211231
212- // The code guarded by __nvvm_reflect may be invalid for the target machine.
213- // Traverse the use-def chain, continually simplifying constant expressions
214- // until we find a terminator that we can then remove.
215- while (!ToSimplify.empty ()) {
216- Instruction *I = ToSimplify.pop_back_val ();
217- if (Constant *C = ConstantFoldInstruction (I, F->getDataLayout ())) {
218- for (User *U : I->users ())
219- if (Instruction *I = dyn_cast<Instruction>(U))
220- ToSimplify.push_back (I);
232+ Call->replaceAllUsesWith (NewValue);
221233
222- I->replaceAllUsesWith (C);
223- if (isInstructionTriviallyDead (I)) {
224- ToRemove.push_back (I);
234+ while (!Worklist.empty ()) {
235+ Instruction *I = Worklist.pop_back_val ();
236+ if (Constant *C = ConstantFoldInstruction (I, Call->getModule ()->getDataLayout ())) {
237+ // Add all users of this instruction to the worklist, replace it with the constant
238+ // then delete it if it's dead
239+ for (User *U : I->users ()) {
240+ if (Instruction *UI = dyn_cast<Instruction>(U))
241+ Worklist.push_back (UI);
225242 }
243+ I->replaceAllUsesWith (C);
244+ if (isInstructionTriviallyDead (I))
245+ I->eraseFromParent ();
226246 } else if (I->isTerminator ()) {
227247 ConstantFoldTerminator (I->getParent ());
228248 }
229249 }
230-
231- // Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
232- // array. Filter out the duplicates before starting to erase from parent.
233- std::sort (ToRemove.begin (), ToRemove.end ());
234- auto *NewLastIter = llvm::unique (ToRemove);
235- ToRemove.erase (NewLastIter, ToRemove.end ());
236-
237- for (Instruction *I : ToRemove)
238- I->eraseFromParent ();
239-
240- // Remove the __nvvm_reflect function from the module
241- F->eraseFromParent ();
242- return ToRemove.size () > 0 ;
243250}
244251
245252bool NVVMReflect::runOnModule (Module &M) {
@@ -250,15 +257,19 @@ bool NVVMReflect::runOnModule(Module &M) {
250257
251258 bool Changed = false ;
252259 // Names of reflect function to find and replace
253- SmallVector<std::string, 3 > ReflectNames = {
254- NVVM_REFLECT_FUNCTION, NVVM_REFLECT_OCL_FUNCTION,
255- Intrinsic::getName (Intrinsic::nvvm_reflect).str ()};
260+ SmallVector<StringRef, 5 > ReflectNames = {
261+ NVVM_REFLECT_FUNCTION,
262+ NVVM_REFLECT_OCL_FUNCTION,
263+ Intrinsic::getName (Intrinsic::nvvm_reflect),
264+ };
256265
257266 // Process all reflect functions
258- for (const std::string &Name : ReflectNames) {
259- Function *ReflectFunction = M.getFunction (Name);
260- if (ReflectFunction) {
261- Changed |= handleReflectFunction (ReflectFunction);
267+ for (StringRef Name : ReflectNames) {
268+ if (Function *ReflectFunction = M.getFunction (Name)) {
269+ // If the reflect functition is called, we need to replace the call
270+ // with the appropriate constant, modifying the IR.
271+ Changed |= ReflectFunction->getNumUses () > 0 ;
272+ handleReflectFunction (ReflectFunction);
262273 }
263274 }
264275
@@ -269,7 +280,9 @@ bool NVVMReflect::runOnModule(Module &M) {
269280NVVMReflectPass::NVVMReflectPass (unsigned SmVersion) {
270281 VarMap[" __CUDA_ARCH" ] = SmVersion * 10 ;
271282}
272- PreservedAnalyses NVVMReflectPass::run (Module &M, ModuleAnalysisManager &AM) {
283+
284+ PreservedAnalyses NVVMReflectPass::run (Module &M,
285+ ModuleAnalysisManager &AM) {
273286 return NVVMReflect (VarMap).runOnModule (M) ? PreservedAnalyses::none ()
274- : PreservedAnalyses::all ();
275- }
287+ : PreservedAnalyses::all ();
288+ }
0 commit comments