Skip to content

Commit 7fd4f03

Browse files
committed
address comments
1 parent af03c15 commit 7fd4f03

File tree

2 files changed

+31
-31
lines changed

2 files changed

+31
-31
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -766,12 +766,12 @@ NVPTX::Scope NVPTXDAGToDAGISel::getOperationScope(MemSDNode *N,
766766
llvm_unreachable("unhandled ordering");
767767
}
768768

769-
static bool canLowerToLDG(const MemSDNode *N, const NVPTXSubtarget &Subtarget,
769+
static bool canLowerToLDG(const MemSDNode &N, const NVPTXSubtarget &Subtarget,
770770
unsigned CodeAddrSpace) {
771771
// We use ldg (i.e. ld.global.nc) for invariant loads from the global address
772772
// space.
773773
return Subtarget.hasLDG() && CodeAddrSpace == NVPTX::AddressSpace::Global &&
774-
N->isInvariant();
774+
N.isInvariant();
775775
}
776776

777777
static unsigned int getFenceOp(NVPTX::Ordering O, NVPTX::Scope S,
@@ -1073,7 +1073,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10731073

10741074
// Address Space Setting
10751075
const unsigned CodeAddrSpace = getCodeAddrSpace(LD);
1076-
if (canLowerToLDG(LD, *Subtarget, CodeAddrSpace))
1076+
if (canLowerToLDG(*LD, *Subtarget, CodeAddrSpace))
10771077
return tryLDGLDU(N);
10781078

10791079
SDLoc DL(N);
@@ -1158,7 +1158,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
11581158

11591159
// Address Space Setting
11601160
const unsigned CodeAddrSpace = getCodeAddrSpace(MemSD);
1161-
if (canLowerToLDG(MemSD, *Subtarget, CodeAddrSpace))
1161+
if (canLowerToLDG(*MemSD, *Subtarget, CodeAddrSpace))
11621162
return tryLDGLDU(N);
11631163

11641164
EVT EltVT = N->getValueType(0);

llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,32 @@
2727

2828
using namespace llvm;
2929

30+
static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
31+
// Don't bother with non-global loads
32+
if (LI->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL)
33+
return false;
34+
35+
// If the load is already marked as invariant, we don't need to do anything
36+
if (LI->getMetadata(LLVMContext::MD_invariant_load))
37+
return false;
38+
39+
// We use getUnderlyingObjects() here instead of getUnderlyingObject()
40+
// mainly because the former looks through phi nodes while the latter does
41+
// not. We need to look through phi nodes to handle pointer induction
42+
// variables.
43+
SmallVector<const Value *, 8> Objs;
44+
getUnderlyingObjects(LI->getPointerOperand(), Objs);
45+
46+
return all_of(Objs, [&](const Value *V) {
47+
if (const auto *A = dyn_cast<const Argument>(V))
48+
return IsKernelFn && ((A->onlyReadsMemory() && A->hasNoAliasAttr()) ||
49+
isParamGridConstant(*A));
50+
if (const auto *GV = dyn_cast<const GlobalVariable>(V))
51+
return GV->isConstant();
52+
return false;
53+
});
54+
}
55+
3056
static void markLoadsAsInvariant(LoadInst *LI) {
3157
LI->setMetadata(LLVMContext::MD_invariant_load,
3258
MDNode::get(LI->getContext(), {}));
@@ -38,38 +64,12 @@ static bool tagInvariantLoads(Function &F) {
3864
bool Changed = false;
3965
for (auto &I : instructions(F)) {
4066
if (auto *LI = dyn_cast<LoadInst>(&I)) {
41-
42-
// Don't bother with non-global loads
43-
if (LI->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL)
44-
continue;
45-
46-
if (LI->getMetadata(LLVMContext::MD_invariant_load))
47-
continue;
48-
49-
SmallVector<const Value *, 8> Objs;
50-
51-
// We use getUnderlyingObjects() here instead of getUnderlyingObject()
52-
// mainly because the former looks through phi nodes while the latter does
53-
// not. We need to look through phi nodes to handle pointer induction
54-
// variables.
55-
56-
getUnderlyingObjects(LI->getPointerOperand(), Objs);
57-
58-
const bool IsInvariant = all_of(Objs, [&](const Value *V) {
59-
if (const auto *A = dyn_cast<const Argument>(V))
60-
return IsKernelFn && A->onlyReadsMemory() && A->hasNoAliasAttr();
61-
if (const auto *GV = dyn_cast<const GlobalVariable>(V))
62-
return GV->isConstant();
63-
return false;
64-
});
65-
66-
if (IsInvariant) {
67+
if (isInvariantLoad(LI, IsKernelFn)) {
6768
markLoadsAsInvariant(LI);
6869
Changed = true;
6970
}
7071
}
7172
}
72-
7373
return Changed;
7474
}
7575

0 commit comments

Comments
 (0)