Skip to content

Commit 21ec838

Browse files
committed
making nvvm reflect more efficient
1 parent b712068 commit 21ec838

File tree

4 files changed

+112
-60
lines changed

4 files changed

+112
-60
lines changed

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ ModulePass *createNVPTXAssignValidGlobalNamesPass();
4343
ModulePass *createGenericToNVVMLegacyPass();
4444
ModulePass *createNVPTXCtorDtorLoweringLegacyPass();
4545
FunctionPass *createNVVMIntrRangePass();
46-
FunctionPass *createNVVMReflectPass(unsigned int SmVersion);
46+
ModulePass *createNVVMReflectPass(unsigned int SmVersion);
4747
MachineFunctionPass *createNVPTXPrologEpilogPass();
4848
MachineFunctionPass *createNVPTXReplaceImageHandlesPass();
4949
FunctionPass *createNVPTXImageOptimizerPass();
@@ -78,12 +78,12 @@ struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
7878
};
7979

8080
struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
81-
NVVMReflectPass();
82-
NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}
83-
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
81+
NVVMReflectPass() : NVVMReflectPass(0) {}
82+
NVVMReflectPass(unsigned SmVersion);
83+
PreservedAnalyses run(Module &F, ModuleAnalysisManager &AM);
8484

8585
private:
86-
unsigned SmVersion;
86+
StringMap<int> VarMap;
8787
};
8888

8989
struct GenericToNVVMPass : PassInfoMixin<GenericToNVVMPass> {

llvm/lib/Target/NVPTX/NVPTXPassRegistry.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#endif
1919
MODULE_PASS("generic-to-nvvm", GenericToNVVMPass())
2020
MODULE_PASS("nvptx-lower-ctor-dtor", NVPTXCtorDtorLoweringPass())
21+
MODULE_PASS("nvvm-reflect", NVVMReflectPass())
2122
#undef MODULE_PASS
2223

2324
#ifndef FUNCTION_ANALYSIS
@@ -36,7 +37,6 @@ FUNCTION_ALIAS_ANALYSIS("nvptx-aa", NVPTXAA())
3637
#define FUNCTION_PASS(NAME, CREATE_PASS)
3738
#endif
3839
FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
39-
FUNCTION_PASS("nvvm-reflect", NVVMReflectPass())
4040
FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
4141
FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this));
4242
#undef FUNCTION_PASS

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,12 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
240240

241241
PB.registerPipelineStartEPCallback(
242242
[this](ModulePassManager &PM, OptimizationLevel Level) {
243-
FunctionPassManager FPM;
244243
// We do not want to fold out calls to nvvm.reflect early if the user
245244
// has not provided a target architecture just yet.
246245
if (Subtarget.hasTargetName())
247-
FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
246+
PM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
247+
248+
FunctionPassManager FPM;
248249
// Note: NVVMIntrRangePass was causing numerical discrepancies at one
249250
// point, if issues crop up, consider disabling.
250251
FPM.addPass(NVVMIntrRangePass());

llvm/lib/Target/NVPTX/NVVMReflect.cpp

Lines changed: 103 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
7-
//===----------------------------------------------------------------------===//
7+
88
//
99
// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
1010
// with an integer.
@@ -25,7 +25,6 @@
2525
#include "llvm/IR/Constants.h"
2626
#include "llvm/IR/DerivedTypes.h"
2727
#include "llvm/IR/Function.h"
28-
#include "llvm/IR/InstIterator.h"
2928
#include "llvm/IR/Instructions.h"
3029
#include "llvm/IR/Intrinsics.h"
3130
#include "llvm/IR/IntrinsicsNVPTX.h"
@@ -39,27 +38,43 @@
3938
#include "llvm/Transforms/Scalar.h"
4039
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
4140
#include "llvm/Transforms/Utils/Local.h"
41+
#include "llvm/Transforms/Utils/StripGCRelocates.h"
4242
#include <algorithm>
4343
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
4444
#define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
4545

4646
using namespace llvm;
4747

48-
#define DEBUG_TYPE "nvptx-reflect"
48+
#define DEBUG_TYPE "nvvm-reflect"
4949

5050
namespace {
51-
class NVVMReflect : public FunctionPass {
51+
class NVVMReflect : public ModulePass {
52+
private:
53+
StringMap<int> VarMap;
54+
/// Process a reflect function by finding all its uses and replacing them with
55+
/// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
56+
/// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
57+
bool handleReflectFunction(Function *F);
58+
void setVarMap(Module &M);
59+
5260
public:
5361
static char ID;
54-
unsigned int SmVersion;
5562
NVVMReflect() : NVVMReflect(0) {}
56-
explicit NVVMReflect(unsigned int Sm) : FunctionPass(ID), SmVersion(Sm) {}
57-
58-
bool runOnFunction(Function &) override;
63+
// __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
64+
// metadata.
65+
explicit NVVMReflect(unsigned int Sm) : ModulePass(ID) {
66+
VarMap["__CUDA_ARCH"] = Sm * 10;
67+
initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
68+
}
69+
// This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
70+
explicit NVVMReflect(const StringMap<int> &Mapping) : ModulePass(ID), VarMap(Mapping) {
71+
initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
72+
}
73+
bool runOnModule(Module &M) override;
5974
};
6075
} // namespace
6176

62-
FunctionPass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
77+
ModulePass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
6378
return new NVVMReflect(SmVersion);
6479
}
6580

@@ -72,27 +87,51 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
7287
"Replace occurrences of __nvvm_reflect() calls with 0/1", false,
7388
false)
7489

75-
static bool runNVVMReflect(Function &F, unsigned SmVersion) {
76-
if (!NVVMReflectEnabled)
77-
return false;
90+
static cl::list<std::string>
91+
ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
92+
cl::desc("A list of string=num assignments"),
93+
cl::ValueRequired);
7894

79-
if (F.getName() == NVVM_REFLECT_FUNCTION ||
80-
F.getName() == NVVM_REFLECT_OCL_FUNCTION) {
81-
assert(F.isDeclaration() && "_reflect function should not have a body");
82-
assert(F.getReturnType()->isIntegerTy() &&
83-
"_reflect's return type should be integer");
84-
return false;
95+
/// The command line can look as follows :
96+
/// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
97+
/// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
98+
/// ReflectList vector. First, each of ReflectList[i] is 'split'
99+
/// using "," as the delimiter. Then each of this part is split
100+
/// using "=" as the delimiter.
101+
void NVVMReflect::setVarMap(Module &M) {
102+
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
103+
M.getModuleFlag("nvvm-reflect-ftz")))
104+
VarMap["__CUDA_FTZ"] = Flag->getSExtValue();
105+
106+
for (unsigned I = 0, E = ReflectList.size(); I != E; ++I) {
107+
LLVM_DEBUG(dbgs() << "Option : " << ReflectList[I] << "\n");
108+
SmallVector<StringRef, 4> NameValList;
109+
StringRef(ReflectList[I]).split(NameValList, ",");
110+
for (unsigned J = 0, EJ = NameValList.size(); J != EJ; ++J) {
111+
SmallVector<StringRef, 2> NameValPair;
112+
NameValList[J].split(NameValPair, "=");
113+
assert(NameValPair.size() == 2 && "name=val expected");
114+
StringRef ValStr = NameValPair[1].trim();
115+
int Val;
116+
if (ValStr.getAsInteger(10, Val))
117+
report_fatal_error("integer value expected");
118+
VarMap[NameValPair[0]] = Val;
119+
}
85120
}
121+
}
122+
123+
bool NVVMReflect::handleReflectFunction(Function *F) {
124+
// Validate _reflect function
125+
assert(F->isDeclaration() && "_reflect function should not have a body");
126+
assert(F->getReturnType()->isIntegerTy() &&
127+
"_reflect's return type should be integer");
86128

87129
SmallVector<Instruction *, 4> ToRemove;
88130
SmallVector<Instruction *, 4> ToSimplify;
89131

90-
// Go through the calls in this function. Each call to __nvvm_reflect or
91-
// llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
92-
// First validate that. If the c-string corresponding to the ConstantArray can
93-
// be found successfully, see if it can be found in VarMap. If so, replace the
94-
// uses of CallInst with the value found in VarMap. If not, replace the use
95-
// with value 0.
132+
// Go through the uses of the reflect function. Each use should be a CallInst
133+
// with a ConstantArray argument. Replace the uses with the appropriate
134+
// constant values.
96135

97136
// The IR for __nvvm_reflect calls differs between CUDA versions.
98137
//
@@ -113,15 +152,10 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
113152
//
114153
// In this case, we get a Constant with a GlobalVariable operand and we need
115154
// to dig deeper to find its initializer with the string we'll use for lookup.
116-
for (Instruction &I : instructions(F)) {
117-
CallInst *Call = dyn_cast<CallInst>(&I);
118-
if (!Call)
119-
continue;
120-
Function *Callee = Call->getCalledFunction();
121-
if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
122-
Callee->getName() != NVVM_REFLECT_OCL_FUNCTION &&
123-
Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
124-
continue;
155+
156+
for (User *U : F->users()) {
157+
assert(isa<CallInst>(U) && "Only a call instruction can use _reflect");
158+
CallInst *Call = cast<CallInst>(U);
125159

126160
// FIXME: Improve error handling here and elsewhere in this pass.
127161
assert(Call->getNumOperands() == 2 &&
@@ -156,20 +190,15 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
156190
"Format of _reflect function not recognized");
157191

158192
StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
193+
// Remove the null terminator from the string
159194
ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
160195
LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
161196

162197
int ReflectVal = 0; // The default value is 0
163-
if (ReflectArg == "__CUDA_FTZ") {
164-
// Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag. Our
165-
// choice here must be kept in sync with AutoUpgrade, which uses the same
166-
// technique to detect whether ftz is enabled.
167-
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
168-
F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
169-
ReflectVal = Flag->getSExtValue();
170-
} else if (ReflectArg == "__CUDA_ARCH") {
171-
ReflectVal = SmVersion * 10;
198+
if (VarMap.contains(ReflectArg)) {
199+
ReflectVal = VarMap[ReflectArg];
172200
}
201+
LLVM_DEBUG(dbgs() << "ReflectVal: " << ReflectVal << "\n");
173202

174203
// If the immediate user is a simple comparison we want to simplify it.
175204
for (User *U : Call->users())
@@ -185,7 +214,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
185214
// until we find a terminator that we can then remove.
186215
while (!ToSimplify.empty()) {
187216
Instruction *I = ToSimplify.pop_back_val();
188-
if (Constant *C = ConstantFoldInstruction(I, F.getDataLayout())) {
217+
if (Constant *C = ConstantFoldInstruction(I, F->getDataLayout())) {
189218
for (User *U : I->users())
190219
if (Instruction *I = dyn_cast<Instruction>(U))
191220
ToSimplify.push_back(I);
@@ -202,23 +231,45 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
202231
// Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
203232
// array. Filter out the duplicates before starting to erase from parent.
204233
std::sort(ToRemove.begin(), ToRemove.end());
205-
auto NewLastIter = llvm::unique(ToRemove);
234+
auto *NewLastIter = llvm::unique(ToRemove);
206235
ToRemove.erase(NewLastIter, ToRemove.end());
207236

208237
for (Instruction *I : ToRemove)
209238
I->eraseFromParent();
210239

240+
// Remove the __nvvm_reflect function from the module
241+
F->eraseFromParent();
211242
return ToRemove.size() > 0;
212243
}
213244

214-
bool NVVMReflect::runOnFunction(Function &F) {
215-
return runNVVMReflect(F, SmVersion);
216-
}
245+
bool NVVMReflect::runOnModule(Module &M) {
246+
if (!NVVMReflectEnabled)
247+
return false;
248+
249+
setVarMap(M);
217250

218-
NVVMReflectPass::NVVMReflectPass() : NVVMReflectPass(0) {}
251+
bool Changed = false;
252+
// Names of reflect function to find and replace
253+
SmallVector<std::string, 3> ReflectNames = {
254+
NVVM_REFLECT_FUNCTION, NVVM_REFLECT_OCL_FUNCTION,
255+
Intrinsic::getName(Intrinsic::nvvm_reflect).str()};
256+
257+
// Process all reflect functions
258+
for (const std::string &Name : ReflectNames) {
259+
Function *ReflectFunction = M.getFunction(Name);
260+
if (ReflectFunction) {
261+
Changed |= handleReflectFunction(ReflectFunction);
262+
}
263+
}
264+
265+
return Changed;
266+
}
219267

220-
PreservedAnalyses NVVMReflectPass::run(Function &F,
221-
FunctionAnalysisManager &AM) {
222-
return runNVVMReflect(F, SmVersion) ? PreservedAnalyses::none()
223-
: PreservedAnalyses::all();
268+
// Implementations for the pass that works with the new pass manager.
269+
NVVMReflectPass::NVVMReflectPass(unsigned SmVersion) {
270+
VarMap["__CUDA_ARCH"] = SmVersion * 10;
224271
}
272+
PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
273+
return NVVMReflect(VarMap).runOnModule(M) ? PreservedAnalyses::none()
274+
: PreservedAnalyses::all();
275+
}

0 commit comments

Comments
 (0)