2727
2828using namespace llvm ;
2929
30+ static bool isInvariantLoad (const LoadInst *LI, const bool IsKernelFn) {
31+ // Don't bother with non-global loads
32+ if (LI->getPointerAddressSpace () != NVPTXAS::ADDRESS_SPACE_GLOBAL)
33+ return false ;
34+
35+ // If the load is already marked as invariant, we don't need to do anything
36+ if (LI->getMetadata (LLVMContext::MD_invariant_load))
37+ return false ;
38+
39+ // We use getUnderlyingObjects() here instead of getUnderlyingObject()
40+ // mainly because the former looks through phi nodes while the latter does
41+ // not. We need to look through phi nodes to handle pointer induction
42+ // variables.
43+ SmallVector<const Value *, 8 > Objs;
44+ getUnderlyingObjects (LI->getPointerOperand (), Objs);
45+
46+ return all_of (Objs, [&](const Value *V) {
47+ if (const auto *A = dyn_cast<const Argument>(V))
48+ return IsKernelFn && ((A->onlyReadsMemory () && A->hasNoAliasAttr ()) ||
49+ isParamGridConstant (*A));
50+ if (const auto *GV = dyn_cast<const GlobalVariable>(V))
51+ return GV->isConstant ();
52+ return false ;
53+ });
54+ }
55+
3056static void markLoadsAsInvariant (LoadInst *LI) {
3157 LI->setMetadata (LLVMContext::MD_invariant_load,
3258 MDNode::get (LI->getContext (), {}));
@@ -38,38 +64,12 @@ static bool tagInvariantLoads(Function &F) {
3864 bool Changed = false ;
3965 for (auto &I : instructions (F)) {
4066 if (auto *LI = dyn_cast<LoadInst>(&I)) {
41-
42- // Don't bother with non-global loads
43- if (LI->getPointerAddressSpace () != NVPTXAS::ADDRESS_SPACE_GLOBAL)
44- continue ;
45-
46- if (LI->getMetadata (LLVMContext::MD_invariant_load))
47- continue ;
48-
49- SmallVector<const Value *, 8 > Objs;
50-
51- // We use getUnderlyingObjects() here instead of getUnderlyingObject()
52- // mainly because the former looks through phi nodes while the latter does
53- // not. We need to look through phi nodes to handle pointer induction
54- // variables.
55-
56- getUnderlyingObjects (LI->getPointerOperand (), Objs);
57-
58- const bool IsInvariant = all_of (Objs, [&](const Value *V) {
59- if (const auto *A = dyn_cast<const Argument>(V))
60- return IsKernelFn && A->onlyReadsMemory () && A->hasNoAliasAttr ();
61- if (const auto *GV = dyn_cast<const GlobalVariable>(V))
62- return GV->isConstant ();
63- return false ;
64- });
65-
66- if (IsInvariant) {
67+ if (isInvariantLoad (LI, IsKernelFn)) {
6768 markLoadsAsInvariant (LI);
6869 Changed = true ;
6970 }
7071 }
7172 }
72-
7373 return Changed;
7474}
7575
0 commit comments