Skip to content

Commit 6fd4fa9

Browse files
committed
final reflect cleanup
1 parent f5444f3 commit 6fd4fa9

File tree

3 files changed

+99
-124
lines changed

3 files changed

+99
-124
lines changed

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
7878
};
7979

8080
struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
81-
NVVMReflectPass() : NVVMReflectPass(0) {}
82-
NVVMReflectPass(unsigned SmVersion);
81+
NVVMReflectPass() : SmVersion(0) {}
82+
NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}
8383
PreservedAnalyses run(Module &F, ModuleAnalysisManager &AM);
8484

8585
private:
86-
StringMap<int> VarMap;
86+
unsigned SmVersion;
8787
};
8888

8989
struct GenericToNVVMPass : PassInfoMixin<GenericToNVVMPass> {

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,27 @@ static cl::opt<bool> EarlyByValArgsCopy(
8787
cl::desc("Create a copy of byval function arguments early."),
8888
cl::init(false), cl::Hidden);
8989

90+
namespace llvm {
91+
92+
void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
93+
void initializeNVPTXAllocaHoistingPass(PassRegistry &);
94+
void initializeNVPTXAssignValidGlobalNamesPass(PassRegistry &);
95+
void initializeNVPTXAtomicLowerPass(PassRegistry &);
96+
void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
97+
void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
98+
void initializeNVPTXLowerAllocaPass(PassRegistry &);
99+
void initializeNVPTXLowerUnreachablePass(PassRegistry &);
100+
void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
101+
void initializeNVPTXLowerArgsLegacyPassPass(PassRegistry &);
102+
void initializeNVPTXProxyRegErasurePass(PassRegistry &);
103+
void initializeNVPTXForwardParamsPassPass(PassRegistry &);
104+
void initializeNVVMIntrRangePass(PassRegistry &);
105+
void initializeNVVMReflectLegacyPassPass(PassRegistry &);
106+
void initializeNVPTXAAWrapperPassPass(PassRegistry &);
107+
void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
108+
109+
} // end namespace llvm
110+
90111
extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
91112
// Register the target.
92113
RegisterTargetMachine<NVPTXTargetMachine32> X(getTheNVPTXTarget32());
@@ -95,7 +116,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
95116
PassRegistry &PR = *PassRegistry::getPassRegistry();
96117
// FIXME: This pass is really intended to be invoked during IR optimization,
97118
// but it's very NVPTX-specific.
98-
initializeNVVMReflectPass(PR);
119+
initializeNVVMReflectLegacyPassPass(PR);
99120
initializeNVVMIntrRangePass(PR);
100121
initializeGenericToNVVMLegacyPassPass(PR);
101122
initializeNVPTXAllocaHoistingPass(PR);

llvm/lib/Target/NVPTX/NVVMReflect.cpp

Lines changed: 74 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -41,44 +41,60 @@
4141
#include "llvm/ADT/StringExtras.h"
4242
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
4343
#define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
44+
// Argument of reflect call to retrive arch number
45+
#define CUDA_ARCH_NAME "__CUDA_ARCH"
46+
// Argument of reflect call to retrive ftz mode
47+
#define CUDA_FTZ_NAME "__CUDA_FTZ"
48+
// Name of module metadata where ftz mode is stored
49+
#define CUDA_FTZ_MODULE_NAME "nvvm-reflect-ftz"
4450

4551
using namespace llvm;
4652

4753
#define DEBUG_TYPE "nvvm-reflect"
4854

55+
namespace llvm {
56+
void initializeNVVMReflectLegacyPassPass(PassRegistry &);
57+
}
58+
4959
namespace {
50-
class NVVMReflect : public ModulePass {
60+
class NVVMReflect {
5161
private:
52-
StringMap<int> VarMap;
53-
void handleReflectFunction(Function *F);
54-
void setVarMap(Module &M);
62+
// Map from reflect function call arguments to the value to replace the call with.
63+
// Should include __CUDA_FTZ and __CUDA_ARCH values.
64+
StringMap<int> ReflectMap;
65+
bool handleReflectFunction(Module &M, StringRef ReflectName);
66+
void populateReflectMap(Module &M);
5567
void foldReflectCall(CallInst *Call, Constant *NewValue);
5668
public:
57-
static char ID;
58-
NVVMReflect() : NVVMReflect(0) {}
5969
// __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
6070
// metadata.
61-
explicit NVVMReflect(unsigned SmVersion) : ModulePass(ID), VarMap({{"__CUDA_ARCH", SmVersion * 10}}) {
62-
initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
63-
}
64-
// This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
65-
explicit NVVMReflect(const StringMap<int> &Mapping) : ModulePass(ID), VarMap(Mapping) {
66-
initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
71+
explicit NVVMReflect(unsigned SmVersion) : ReflectMap({{CUDA_ARCH_NAME, SmVersion * 10}}) {}
72+
bool runOnModule(Module &M);
73+
};
74+
} // namespace
75+
76+
class NVVMReflectLegacyPass : public ModulePass {
77+
private:
78+
NVVMReflect Impl;
79+
public:
80+
static char ID;
81+
NVVMReflectLegacyPass(unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {
82+
initializeNVVMReflectLegacyPassPass(*PassRegistry::getPassRegistry());
6783
}
6884
bool runOnModule(Module &M) override;
6985
};
70-
} // namespace
7186

7287
ModulePass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
73-
return new NVVMReflect(SmVersion);
88+
LLVM_DEBUG(dbgs() << "Creating NVVMReflectPass with SM version " << SmVersion << "\n");
89+
return new NVVMReflectLegacyPass(SmVersion);
7490
}
7591

7692
static cl::opt<bool>
7793
NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
7894
cl::desc("NVVM reflection, enabled by default"));
7995

80-
char NVVMReflect::ID = 0;
81-
INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
96+
char NVVMReflectLegacyPass::ID = 0;
97+
INITIALIZE_PASS(NVVMReflectLegacyPass, "nvvm-reflect",
8298
"Replace occurrences of __nvvm_reflect() calls with 0/1", false,
8399
false)
84100

@@ -92,10 +108,10 @@ ReflectList("nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
92108

93109
// Set the VarMap with, first, the value of __CUDA_FTZ from module metadata, and then
94110
// the key/value pairs from the command line.
95-
void NVVMReflect::setVarMap(Module &M) {
111+
void NVVMReflect::populateReflectMap(Module &M) {
96112
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
97-
M.getModuleFlag("nvvm-reflect-ftz")))
98-
VarMap["__CUDA_FTZ"] = Flag->getSExtValue();
113+
M.getModuleFlag(CUDA_FTZ_MODULE_NAME)))
114+
ReflectMap[CUDA_FTZ_NAME] = Flag->getSExtValue();
99115

100116
for (StringRef Option : ReflectList) {
101117
LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
@@ -107,94 +123,52 @@ void NVVMReflect::setVarMap(Module &M) {
107123
int ValInt;
108124
if (!to_integer(Val.trim(), ValInt, 10))
109125
report_fatal_error("integer value expected in nvvm-reflect-list option '" + Option + "'");
110-
VarMap[Name] = ValInt;
126+
ReflectMap[Name] = ValInt;
111127
}
112128
}
113129

114-
/// Process a reflect function by finding all its uses and replacing them with
130+
/// Process a reflect function by finding all its calls and replacing them with
115131
/// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
116132
/// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
117-
void NVVMReflect::handleReflectFunction(Function *F) {
118-
// Validate _reflect function
133+
bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
134+
Function *F = M.getFunction(ReflectName);
135+
if (!F)
136+
return false;
119137
assert(F->isDeclaration() && "_reflect function should not have a body");
120138
assert(F->getReturnType()->isIntegerTy() && "_reflect's return type should be integer");
121139

122-
123-
// Go through the uses of the reflect function. Each use should be a CallInst
124-
// with a ConstantArray argument. Replace the uses with the appropriate constant values.
125-
126-
// The IR for __nvvm_reflect calls differs between CUDA versions.
127-
//
128-
// CUDA 6.5 and earlier uses this sequence:
129-
// %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8
130-
// (i8 addrspace(4)* getelementptr inbounds
131-
// ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0))
132-
// %reflect = tail call i32 @__nvvm_reflect(i8* %ptr)
133-
//
134-
// The value returned by Sym->getOperand(0) is a Constant with a
135-
// ConstantDataSequential operand which can be converted to string and used
136-
// for lookup.
137-
//
138-
// CUDA 7.0 does it slightly differently:
139-
// %reflect = call i32 @__nvvm_reflect(i8* addrspacecast
140-
// (i8 addrspace(1)* getelementptr inbounds
141-
// ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
142-
//
143-
// In this case, we get a Constant with a GlobalVariable operand and we need
144-
// to dig deeper to find its initializer with the string we'll use for lookup.
145-
140+
bool Changed = F->getNumUses() > 0;
146141
for (User *U : make_early_inc_range(F->users())) {
142+
// Reflect function calls look like:
143+
// @arch = private unnamed_addr addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
144+
// call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @arch to ptr))
145+
// We need to extract the string argument from the call (i.e. "__CUDA_ARCH")
147146
if (!isa<CallInst>(U))
148147
report_fatal_error("__nvvm_reflect can only be used in a call instruction");
149148
CallInst *Call = cast<CallInst>(U);
150-
151149
if (Call->getNumOperands() != 2)
152150
report_fatal_error("__nvvm_reflect requires exactly one argument");
153151

154-
// In cuda 6.5 and earlier, we will have an extra constant-to-generic
155-
// conversion of the string.
156-
const Value *Str = Call->getArgOperand(0);
157-
if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
158-
// Verify this is the constant-to-generic intrinsic
159-
Function *Callee = ConvCall->getCalledFunction();
160-
if (!Callee || !Callee->isIntrinsic() ||
161-
!Callee->getName().starts_with("llvm.nvvm.ptr.constant.to.gen"))
162-
report_fatal_error("Expected llvm.nvvm.ptr.constant.to.gen intrinsic");
163-
if (ConvCall->getNumOperands() != 2)
164-
report_fatal_error("Expected one argument for ptr conversion");
165-
Str = ConvCall->getArgOperand(0);
166-
}
167-
// Pre opaque pointers we have a constant expression wrapping the constant
168-
Str = Str->stripPointerCasts();
169-
if (!isa<Constant>(Str))
152+
const Value *GlobalStr = Call->getArgOperand(0)->stripPointerCasts();
153+
if (!isa<Constant>(GlobalStr))
170154
report_fatal_error("__nvvm_reflect argument must be a constant string");
171155

172-
const Value *Operand = cast<Constant>(Str)->getOperand(0);
173-
if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
174-
// For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
175-
// initializer.
176-
if (!GV->hasInitializer())
177-
report_fatal_error("__nvvm_reflect string must have an initializer");
178-
const Constant *Initializer = GV->getInitializer();
179-
Operand = Initializer;
180-
}
181-
182-
if (!isa<ConstantDataSequential>(Operand))
156+
const Value *ConstantStr = cast<Constant>(GlobalStr)->getOperand(0);
157+
if (!isa<ConstantDataSequential>(ConstantStr))
183158
report_fatal_error("__nvvm_reflect argument must be a string constant");
184-
if (!cast<ConstantDataSequential>(Operand)->isCString())
159+
if (!cast<ConstantDataSequential>(ConstantStr)->isCString())
185160
report_fatal_error("__nvvm_reflect argument must be a null-terminated string");
186161

187-
StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
162+
StringRef ReflectArg = cast<ConstantDataSequential>(ConstantStr)->getAsString();
188163
// Remove the null terminator from the string
189164
ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
190-
191165
if (ReflectArg.empty())
192166
report_fatal_error("__nvvm_reflect argument cannot be empty");
193-
167+
// Now that we have extracted the string argument, we can look it up in the VarMap
194168
int ReflectVal = 0; // The default value is 0
195-
if (VarMap.contains(ReflectArg)) {
196-
ReflectVal = VarMap[ReflectArg];
197-
}
169+
if (ReflectMap.contains(ReflectArg))
170+
ReflectVal = ReflectMap[ReflectArg];
171+
198172
LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName() << "(" << ReflectArg << ") with value " << ReflectVal << "\n");
199173
Constant *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
200174
foldReflectCall(Call, NewValue);
@@ -203,29 +177,26 @@ void NVVMReflect::handleReflectFunction(Function *F) {
203177

204178
// Remove the __nvvm_reflect function from the module
205179
F->eraseFromParent();
180+
return Changed;
206181
}
207182

208183
void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
209-
// Initialize worklist with all users of the call
210184
SmallVector<Instruction*, 8> Worklist;
211-
for (User *U : Call->users()) {
212-
if (Instruction *I = dyn_cast<Instruction>(U)) {
213-
Worklist.push_back(I);
185+
// Replace an instruction with a constant and add all users of the instruction to the worklist
186+
auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
187+
for (User *U : I->users()) {
188+
if (Instruction *UI = dyn_cast<Instruction>(U))
189+
Worklist.push_back(UI);
214190
}
215-
}
191+
I->replaceAllUsesWith(C);
192+
};
216193

217-
Call->replaceAllUsesWith(NewValue);
194+
ReplaceInstructionWithConst(Call, NewValue);
218195

219196
while (!Worklist.empty()) {
220197
Instruction *I = Worklist.pop_back_val();
221198
if (Constant *C = ConstantFoldInstruction(I, Call->getModule()->getDataLayout())) {
222-
// Add all users of this instruction to the worklist, replace it with the constant
223-
// then delete it if it's dead
224-
for (User *U : I->users()) {
225-
if (Instruction *UI = dyn_cast<Instruction>(U))
226-
Worklist.push_back(UI);
227-
}
228-
I->replaceAllUsesWith(C);
199+
ReplaceInstructionWithConst(I, C);
229200
if (isInstructionTriviallyDead(I))
230201
I->eraseFromParent();
231202
} else if (I->isTerminator()) {
@@ -237,37 +208,20 @@ void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
237208
bool NVVMReflect::runOnModule(Module &M) {
238209
if (!NVVMReflectEnabled)
239210
return false;
240-
241-
setVarMap(M);
242-
243-
bool Changed = false;
244-
// Names of reflect function to find and replace
245-
SmallVector<StringRef, 5> ReflectNames = {
246-
NVVM_REFLECT_FUNCTION,
247-
NVVM_REFLECT_OCL_FUNCTION,
248-
Intrinsic::getName(Intrinsic::nvvm_reflect),
249-
};
250-
251-
// Process all reflect functions
252-
for (StringRef Name : ReflectNames) {
253-
if (Function *ReflectFunction = M.getFunction(Name)) {
254-
// If the reflect functition is called, we need to replace the call
255-
// with the appropriate constant, modifying the IR.
256-
Changed |= ReflectFunction->getNumUses() > 0;
257-
handleReflectFunction(ReflectFunction);
258-
}
259-
}
260-
211+
populateReflectMap(M);
212+
bool Changed = true;
213+
handleReflectFunction(M, NVVM_REFLECT_FUNCTION);
214+
handleReflectFunction(M, NVVM_REFLECT_OCL_FUNCTION);
215+
handleReflectFunction(M, Intrinsic::getName(Intrinsic::nvvm_reflect));
261216
return Changed;
262217
}
263218

264-
// Implementations for the pass that works with the new pass manager.
265-
NVVMReflectPass::NVVMReflectPass(unsigned SmVersion) {
266-
VarMap["__CUDA_ARCH"] = SmVersion * 10;
219+
bool NVVMReflectLegacyPass::runOnModule(Module &M) {
220+
return Impl.runOnModule(M);
267221
}
268222

269223
PreservedAnalyses NVVMReflectPass::run(Module &M,
270224
ModuleAnalysisManager &AM) {
271-
return NVVMReflect(VarMap).runOnModule(M) ? PreservedAnalyses::none()
225+
return NVVMReflect(SmVersion).runOnModule(M) ? PreservedAnalyses::none()
272226
: PreservedAnalyses::all();
273227
}

0 commit comments

Comments
 (0)