Skip to content

Commit f167f2c

Browse files
committed
reflect improvements
1 parent db16866 commit f167f2c

File tree

2 files changed

+124
-85
lines changed

2 files changed

+124
-85
lines changed

llvm/lib/Target/NVPTX/NVVMReflect.cpp

Lines changed: 98 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
7+
// NVIDIA_COPYRIGHT_BEGIN
8+
//
9+
// Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved.
10+
//
11+
// NVIDIA_COPYRIGHT_END
12+
//
713
//===----------------------------------------------------------------------===//
814
//
915
// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
@@ -38,8 +44,7 @@
3844
#include "llvm/Transforms/Scalar.h"
3945
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
4046
#include "llvm/Transforms/Utils/Local.h"
41-
#include "llvm/Transforms/Utils/StripGCRelocates.h"
42-
#include <algorithm>
47+
#include "llvm/ADT/StringExtras.h"
4348
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
4449
#define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
4550

@@ -54,16 +59,15 @@ class NVVMReflect : public ModulePass {
5459
/// Process a reflect function by finding all its uses and replacing them with
5560
/// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
5661
/// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
57-
bool handleReflectFunction(Function *F);
62+
void handleReflectFunction(Function *F);
5863
void setVarMap(Module &M);
59-
64+
void foldReflectCall(CallInst *Call, Constant *NewValue);
6065
public:
6166
static char ID;
6267
NVVMReflect() : NVVMReflect(0) {}
6368
// __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
6469
// metadata.
65-
explicit NVVMReflect(unsigned int Sm) : ModulePass(ID) {
66-
VarMap["__CUDA_ARCH"] = Sm * 10;
70+
explicit NVVMReflect(unsigned SmVersion) : ModulePass(ID), VarMap({{"__CUDA_ARCH", SmVersion * 10}}) {
6771
initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
6872
}
6973
// This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
@@ -87,51 +91,58 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
8791
"Replace occurrences of __nvvm_reflect() calls with 0/1", false,
8892
false)
8993

94+
// Allow users to specify additional key/value pairs to reflect. These key/value pairs
95+
// are the last to be added to the VarMap, and therefore will take precedence over initial
96+
// values (i.e. __CUDA_FTZ from module medadata and __CUDA_ARCH from SmVersion).
9097
static cl::list<std::string>
91-
ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
92-
cl::desc("A list of string=num assignments"),
93-
cl::ValueRequired);
94-
95-
/// The command line can look as follows :
96-
/// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
97-
/// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
98-
/// ReflectList vector. First, each of ReflectList[i] is 'split'
99-
/// using "," as the delimiter. Then each of this part is split
100-
/// using "=" as the delimiter.
98+
ReflectList("nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
99+
cl::desc("list of comma-separated key=value pairs"),
100+
cl::ValueRequired);
101+
102+
// Set the VarMap with, first, the value of __CUDA_FTZ from module metadata, and then
103+
// the key/value pairs from the command line.
101104
void NVVMReflect::setVarMap(Module &M) {
105+
LLVM_DEBUG(dbgs() << "Reflect list values:\n");
106+
for (StringRef Option : ReflectList) {
107+
LLVM_DEBUG(dbgs() << " " << Option << "\n");
108+
}
102109
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
103-
M.getModuleFlag("nvvm-reflect-ftz")))
110+
M.getModuleFlag("nvvm-reflect-ftz")))
104111
VarMap["__CUDA_FTZ"] = Flag->getSExtValue();
105112

106-
for (unsigned I = 0, E = ReflectList.size(); I != E; ++I) {
107-
LLVM_DEBUG(dbgs() << "Option : " << ReflectList[I] << "\n");
108-
SmallVector<StringRef, 4> NameValList;
109-
StringRef(ReflectList[I]).split(NameValList, ",");
110-
for (unsigned J = 0, EJ = NameValList.size(); J != EJ; ++J) {
111-
SmallVector<StringRef, 2> NameValPair;
112-
NameValList[J].split(NameValPair, "=");
113-
assert(NameValPair.size() == 2 && "name=val expected");
114-
StringRef ValStr = NameValPair[1].trim();
113+
/// The command line can look as follows :
114+
/// -nvvm-reflect-add a=1,b=2 -nvvm-reflect-add c=3,d=0 -nvvm-reflect-add e=2
115+
/// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
116+
/// ReflectList vector. First, each of ReflectList[i] is 'split'
117+
/// using "," as the delimiter. Then each of this part is split
118+
/// using "=" as the delimiter.
119+
for (StringRef Option : ReflectList) {
120+
LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
121+
while (!Option.empty()) {
122+
std::pair<StringRef, StringRef> Split = Option.split(',');
123+
StringRef NameVal = Split.first;
124+
Option = Split.second;
125+
126+
auto NameValPair = NameVal.split('=');
127+
assert(!NameValPair.first.empty() && !NameValPair.second.empty() &&
128+
"name=val expected");
129+
115130
int Val;
116-
if (ValStr.getAsInteger(10, Val))
131+
if (!to_integer(NameValPair.second.trim(), Val, 10))
117132
report_fatal_error("integer value expected");
118-
VarMap[NameValPair[0]] = Val;
133+
VarMap[NameValPair.first] = Val;
119134
}
120135
}
121136
}
122137

123-
bool NVVMReflect::handleReflectFunction(Function *F) {
138+
void NVVMReflect::handleReflectFunction(Function *F) {
124139
// Validate _reflect function
125140
assert(F->isDeclaration() && "_reflect function should not have a body");
126-
assert(F->getReturnType()->isIntegerTy() &&
127-
"_reflect's return type should be integer");
141+
assert(F->getReturnType()->isIntegerTy() && "_reflect's return type should be integer");
128142

129-
SmallVector<Instruction *, 4> ToRemove;
130-
SmallVector<Instruction *, 4> ToSimplify;
131143

132144
// Go through the uses of the reflect function. Each use should be a CallInst
133-
// with a ConstantArray argument. Replace the uses with the appropriate
134-
// constant values.
145+
// with a ConstantArray argument. Replace the uses with the appropriate constant values.
135146

136147
// The IR for __nvvm_reflect calls differs between CUDA versions.
137148
//
@@ -153,7 +164,7 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
153164
// In this case, we get a Constant with a GlobalVariable operand and we need
154165
// to dig deeper to find its initializer with the string we'll use for lookup.
155166

156-
for (User *U : F->users()) {
167+
for (User *U : make_early_inc_range(F->users())) {
157168
assert(isa<CallInst>(U) && "Only a call instruction can use _reflect");
158169
CallInst *Call = cast<CallInst>(U);
159170

@@ -165,21 +176,23 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
165176
// conversion of the string.
166177
const Value *Str = Call->getArgOperand(0);
167178
if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
168-
// FIXME: Add assertions about ConvCall.
179+
// Verify this is the constant-to-generic intrinsic
180+
Function *Callee = ConvCall->getCalledFunction();
181+
assert(Callee && Callee->isIntrinsic() &&
182+
Callee->getName().starts_with("llvm.nvvm.ptr.constant.to.gen") &&
183+
"Expected llvm.nvvm.ptr.constant.to.gen intrinsic");
184+
assert(ConvCall->getNumOperands() == 2 && "Expected one argument for ptr conversion");
169185
Str = ConvCall->getArgOperand(0);
170186
}
171187
// Pre opaque pointers we have a constant expression wrapping the constant
172-
// string.
173188
Str = Str->stripPointerCasts();
174-
assert(isa<Constant>(Str) &&
175-
"Format of __nvvm_reflect function not recognized");
189+
assert(isa<Constant>(Str) && "Format of __nvvm_reflect function not recognized");
176190

177191
const Value *Operand = cast<Constant>(Str)->getOperand(0);
178192
if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
179193
// For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
180194
// initializer.
181-
assert(GV->hasInitializer() &&
182-
"Format of _reflect function not recognized");
195+
assert(GV->hasInitializer() && "Format of _reflect function not recognized");
183196
const Constant *Initializer = GV->getInitializer();
184197
Operand = Initializer;
185198
}
@@ -192,54 +205,48 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
192205
StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
193206
// Remove the null terminator from the string
194207
ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
195-
LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
196208

197209
int ReflectVal = 0; // The default value is 0
198210
if (VarMap.contains(ReflectArg)) {
199211
ReflectVal = VarMap[ReflectArg];
200212
}
201-
LLVM_DEBUG(dbgs() << "ReflectVal: " << ReflectVal << "\n");
213+
LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName() << "(" << ReflectArg << ") with value " << ReflectVal << "\n");
214+
Constant *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
215+
foldReflectCall(Call, NewValue);
216+
Call->eraseFromParent();
217+
}
202218

203-
// If the immediate user is a simple comparison we want to simplify it.
204-
for (User *U : Call->users())
205-
if (Instruction *I = dyn_cast<Instruction>(U))
206-
ToSimplify.push_back(I);
219+
// Remove the __nvvm_reflect function from the module
220+
F->eraseFromParent();
221+
}
207222

208-
Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
209-
ToRemove.push_back(Call);
223+
void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
224+
// Initialize worklist with all users of the call
225+
SmallVector<Instruction*, 8> Worklist;
226+
for (User *U : Call->users()) {
227+
if (Instruction *I = dyn_cast<Instruction>(U)) {
228+
Worklist.push_back(I);
229+
}
210230
}
211231

212-
// The code guarded by __nvvm_reflect may be invalid for the target machine.
213-
// Traverse the use-def chain, continually simplifying constant expressions
214-
// until we find a terminator that we can then remove.
215-
while (!ToSimplify.empty()) {
216-
Instruction *I = ToSimplify.pop_back_val();
217-
if (Constant *C = ConstantFoldInstruction(I, F->getDataLayout())) {
218-
for (User *U : I->users())
219-
if (Instruction *I = dyn_cast<Instruction>(U))
220-
ToSimplify.push_back(I);
232+
Call->replaceAllUsesWith(NewValue);
221233

222-
I->replaceAllUsesWith(C);
223-
if (isInstructionTriviallyDead(I)) {
224-
ToRemove.push_back(I);
234+
while (!Worklist.empty()) {
235+
Instruction *I = Worklist.pop_back_val();
236+
if (Constant *C = ConstantFoldInstruction(I, Call->getModule()->getDataLayout())) {
237+
// Add all users of this instruction to the worklist, replace it with the constant
238+
// then delete it if it's dead
239+
for (User *U : I->users()) {
240+
if (Instruction *UI = dyn_cast<Instruction>(U))
241+
Worklist.push_back(UI);
225242
}
243+
I->replaceAllUsesWith(C);
244+
if (isInstructionTriviallyDead(I))
245+
I->eraseFromParent();
226246
} else if (I->isTerminator()) {
227247
ConstantFoldTerminator(I->getParent());
228248
}
229249
}
230-
231-
// Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
232-
// array. Filter out the duplicates before starting to erase from parent.
233-
std::sort(ToRemove.begin(), ToRemove.end());
234-
auto *NewLastIter = llvm::unique(ToRemove);
235-
ToRemove.erase(NewLastIter, ToRemove.end());
236-
237-
for (Instruction *I : ToRemove)
238-
I->eraseFromParent();
239-
240-
// Remove the __nvvm_reflect function from the module
241-
F->eraseFromParent();
242-
return ToRemove.size() > 0;
243250
}
244251

245252
bool NVVMReflect::runOnModule(Module &M) {
@@ -250,15 +257,19 @@ bool NVVMReflect::runOnModule(Module &M) {
250257

251258
bool Changed = false;
252259
// Names of reflect function to find and replace
253-
SmallVector<std::string, 3> ReflectNames = {
254-
NVVM_REFLECT_FUNCTION, NVVM_REFLECT_OCL_FUNCTION,
255-
Intrinsic::getName(Intrinsic::nvvm_reflect).str()};
260+
SmallVector<StringRef, 5> ReflectNames = {
261+
NVVM_REFLECT_FUNCTION,
262+
NVVM_REFLECT_OCL_FUNCTION,
263+
Intrinsic::getName(Intrinsic::nvvm_reflect),
264+
};
256265

257266
// Process all reflect functions
258-
for (const std::string &Name : ReflectNames) {
259-
Function *ReflectFunction = M.getFunction(Name);
260-
if (ReflectFunction) {
261-
Changed |= handleReflectFunction(ReflectFunction);
267+
for (StringRef Name : ReflectNames) {
268+
if (Function *ReflectFunction = M.getFunction(Name)) {
269+
// If the reflect functition is called, we need to replace the call
270+
// with the appropriate constant, modifying the IR.
271+
Changed |= ReflectFunction->getNumUses() > 0;
272+
handleReflectFunction(ReflectFunction);
262273
}
263274
}
264275

@@ -269,7 +280,9 @@ bool NVVMReflect::runOnModule(Module &M) {
269280
NVVMReflectPass::NVVMReflectPass(unsigned SmVersion) {
270281
VarMap["__CUDA_ARCH"] = SmVersion * 10;
271282
}
272-
PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
283+
284+
PreservedAnalyses NVVMReflectPass::run(Module &M,
285+
ModuleAnalysisManager &AM) {
273286
return NVVMReflect(VarMap).runOnModule(M) ? PreservedAnalyses::none()
274-
: PreservedAnalyses::all();
275-
}
287+
: PreservedAnalyses::all();
288+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
; Verify that when passing in command-line options to NVVMReflect, that reflect calls are replaced with
2+
; the appropriate command line values.
3+
4+
declare i32 @__nvvm_reflect(ptr)
5+
@ftz = private unnamed_addr addrspace(1) constant [11 x i8] c"__CUDA_FTZ\00"
6+
@arch = private unnamed_addr addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
7+
8+
; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=1,__CUDA_ARCH=350 %s -S | FileCheck %s --check-prefix=CHECK-FTZ1-ARCH350
9+
; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=0 -nvvm-reflect-add=__CUDA_ARCH=520 %s -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
10+
11+
; Verify that if we have module metadata that sets __CUDA_FTZ=1, that gets overridden by the command line arguments
12+
13+
; RUN: cat %s > %t.options
14+
; RUN: echo '!llvm.module.flags = !{!0}' >> %t.options
15+
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.options
16+
; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=0,__CUDA_ARCH=520 %t.options -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
17+
18+
define i32 @options() {
19+
%1 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @ftz to ptr))
20+
%2 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @arch to ptr))
21+
%3 = add i32 %1, %2
22+
ret i32 %3
23+
}
24+
25+
; CHECK-FTZ1-ARCH350: ret i32 351
26+
; CHECK-FTZ0-ARCH520: ret i32 520

0 commit comments

Comments
 (0)