Skip to content

Commit 519e5da

Browse files
committed
Address review
1 parent 5cd00d0 commit 519e5da

File tree

5 files changed

+58
-68
lines changed

5 files changed

+58
-68
lines changed

llvm/include/llvm/Analysis/CaptureTracking.h

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#define LLVM_ANALYSIS_CAPTURETRACKING_H
1515

1616
#include "llvm/ADT/DenseMap.h"
17-
#include "llvm/Support/ModRef.h"
1817

1918
namespace llvm {
2019

@@ -83,6 +82,21 @@ namespace llvm {
8382
/// addition to the interface here, you'll need to provide your own getters
8483
/// to see whether anything was captured.
8584
struct CaptureTracker {
85+
/// Action returned from captures().
86+
enum Action {
87+
/// Stop the traversal.
88+
Stop,
89+
/// Continue traversal, and also follow the return value of the user if
90+
/// it has additional capture components (that is, if it has capture
91+
/// components in Ret that are not part of Other).
92+
Continue,
93+
/// Continue traversal, but do not follow the return value of the user,
94+
/// even if it has additional capture components. Should only be used if
95+
/// captures() has already taken the potential return captures into
96+
/// account.
97+
ContinueIgnoringReturn,
98+
};
99+
86100
virtual ~CaptureTracker();
87101

88102
/// tooManyUses - The depth of traversal has breached a limit. There may be
@@ -96,38 +110,12 @@ namespace llvm {
96110
/// U->getUser() is always an Instruction.
97111
virtual bool shouldExplore(const Use *U);
98112

99-
/// When returned from captures(), stop the traversal.
100-
static std::optional<CaptureComponents> stop() { return std::nullopt; }
101-
102-
/// When returned from captures(), continue traversal, but do not follow
103-
/// the return value of this user, even if it has additional capture
104-
/// components. Should only be used if captures() has already taken the
105-
/// potential return caputres into account.
106-
static std::optional<CaptureComponents> continueIgnoringReturn() {
107-
return CaptureComponents::None;
108-
}
109-
110-
/// When returned from captures(), continue traversal, and also follow
111-
/// the return value of this user if it has additional capture components
112-
/// (that is, capture components in Ret that are not part of Other).
113-
static std::optional<CaptureComponents> continueDefault(CaptureInfo CI) {
114-
CaptureComponents RetCC = CI.getRetComponents();
115-
if (!capturesNothing(RetCC & ~CI.getOtherComponents()))
116-
return RetCC;
117-
return CaptureComponents::None;
118-
}
119-
120113
/// Use U directly captures CI.getOtherComponents() and additionally
121114
/// CI.getRetComponents() through the return value of the user of U.
122115
///
123-
/// Return std::nullopt to stop the traversal, or the CaptureComponents to
124-
/// follow via the return value, which must be a subset of
125-
/// CI.getRetComponents().
126-
///
127-
/// For convenience, prefer returning one of stop(), continueDefault(CI) or
128-
/// continueIgnoringReturn().
129-
virtual std::optional<CaptureComponents> captured(const Use *U,
130-
CaptureInfo CI) = 0;
116+
/// Return one of Stop, Continue or ContinueIgnoringReturn to control
117+
/// further traversal.
118+
virtual Action captured(const Use *U, CaptureInfo CI) = 0;
131119

132120
/// isDereferenceableOrNull - Overload to allow clients with additional
133121
/// knowledge about pointer dereferenceability to provide it and thereby

llvm/lib/Analysis/CaptureTracking.cpp

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,15 @@ struct SimpleCaptureTracker : public CaptureTracker {
8181
Captured = true;
8282
}
8383

84-
std::optional<CaptureComponents> captured(const Use *U,
85-
CaptureInfo CI) override {
84+
Action captured(const Use *U, CaptureInfo CI) override {
8685
// TODO(captures): Use CaptureInfo.
8786
if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures)
88-
return continueIgnoringReturn();
87+
return ContinueIgnoringReturn;
8988

9089
LLVM_DEBUG(dbgs() << "Captured by: " << *U->getUser() << "\n");
9190

9291
Captured = true;
93-
return stop();
92+
return Stop;
9493
}
9594

9695
bool ReturnCaptures;
@@ -124,22 +123,21 @@ struct CapturesBefore : public CaptureTracker {
124123
return !isPotentiallyReachable(I, BeforeHere, nullptr, DT, LI);
125124
}
126125

127-
std::optional<CaptureComponents> captured(const Use *U,
128-
CaptureInfo CI) override {
126+
Action captured(const Use *U, CaptureInfo CI) override {
129127
// TODO(captures): Use CaptureInfo.
130128
Instruction *I = cast<Instruction>(U->getUser());
131129
if (isa<ReturnInst>(I) && !ReturnCaptures)
132-
return continueIgnoringReturn();
130+
return ContinueIgnoringReturn;
133131

134132
// Check isSafeToPrune() here rather than in shouldExplore() to avoid
135133
// an expensive reachability query for every instruction we look at.
136134
// Instead we only do one for actual capturing candidates.
137135
if (isSafeToPrune(I))
138136
// If the use is not reachable, the instruction result isn't either.
139-
return continueIgnoringReturn();
137+
return ContinueIgnoringReturn;
140138

141139
Captured = true;
142-
return stop();
140+
return Stop;
143141
}
144142

145143
const Instruction *BeforeHere;
@@ -171,12 +169,11 @@ struct EarliestCaptures : public CaptureTracker {
171169
EarliestCapture = &*F.getEntryBlock().begin();
172170
}
173171

174-
std::optional<CaptureComponents> captured(const Use *U,
175-
CaptureInfo CI) override {
172+
Action captured(const Use *U, CaptureInfo CI) override {
176173
// TODO(captures): Use CaptureInfo.
177174
Instruction *I = cast<Instruction>(U->getUser());
178175
if (isa<ReturnInst>(I) && !ReturnCaptures)
179-
return continueIgnoringReturn();
176+
return ContinueIgnoringReturn;
180177

181178
if (!EarliestCapture)
182179
EarliestCapture = I;
@@ -187,7 +184,7 @@ struct EarliestCaptures : public CaptureTracker {
187184
// Continue analysis, as we need to see all potential captures. However,
188185
// we do not need to follow the instruction result, as this use will
189186
// dominate any captures made through the instruction result..
190-
return continueIgnoringReturn();
187+
return ContinueIgnoringReturn;
191188
}
192189

193190
Instruction *EarliestCapture = nullptr;
@@ -451,17 +448,24 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker,
451448
CaptureInfo CI = DetermineUseCaptureKind(*U, IsDereferenceableOrNull);
452449
if (capturesNothing(CI))
453450
continue;
451+
CaptureComponents OtherCC = CI.getOtherComponents();
454452
CaptureComponents RetCC = CI.getRetComponents();
455-
if (!capturesNothing(CI.getOtherComponents())) {
456-
std::optional<CaptureComponents> Res = Tracker->captured(U, CI);
457-
if (!Res)
453+
if (capturesAnything(OtherCC)) {
454+
switch (Tracker->captured(U, CI)) {
455+
case CaptureTracker::Stop:
458456
return;
459-
assert(capturesNothing(*Res & ~RetCC) &&
460-
"captures() result must be subset of getRetComponents()");
461-
RetCC = *Res;
457+
case CaptureTracker::ContinueIgnoringReturn:
458+
continue;
459+
case CaptureTracker::Continue:
460+
// Fall through to passthrough handling, but only if RetCC contains
461+
// additional components that OtherCC does not.
462+
if (capturesNothing(RetCC & ~OtherCC))
463+
continue;
464+
break;
465+
}
462466
}
463467
// TODO(captures): We could keep track of RetCC for the users.
464-
if (!capturesNothing(RetCC) && !AddUses(U->getUser()))
468+
if (capturesAnything(RetCC) && !AddUses(U->getUser()))
465469
return;
466470
}
467471

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2788,8 +2788,7 @@ static Constant *computePointerICmp(CmpPredicate Pred, Value *LHS, Value *RHS,
27882788
struct CustomCaptureTracker : public CaptureTracker {
27892789
bool Captured = false;
27902790
void tooManyUses() override { Captured = true; }
2791-
std::optional<CaptureComponents> captured(const Use *U,
2792-
CaptureInfo CI) override {
2791+
Action captured(const Use *U, CaptureInfo CI) override {
27932792
// TODO(captures): Use CaptureInfo.
27942793
if (auto *ICmp = dyn_cast<ICmpInst>(U->getUser())) {
27952794
// Comparison against value stored in global variable. Given the
@@ -2798,11 +2797,11 @@ static Constant *computePointerICmp(CmpPredicate Pred, Value *LHS, Value *RHS,
27982797
unsigned OtherIdx = 1 - U->getOperandNo();
27992798
auto *LI = dyn_cast<LoadInst>(ICmp->getOperand(OtherIdx));
28002799
if (LI && isa<GlobalVariable>(LI->getPointerOperand()))
2801-
return continueDefault(CI);
2800+
return Continue;
28022801
}
28032802

28042803
Captured = true;
2805-
return stop();
2804+
return Stop;
28062805
}
28072806
};
28082807
CustomCaptureTracker Tracker;

llvm/lib/Transforms/IPO/FunctionAttrs.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ using namespace llvm;
7272

7373
STATISTIC(NumMemoryAttr, "Number of functions with improved memory attribute");
7474
STATISTIC(NumCapturesNone, "Number of arguments marked captures(none)");
75-
STATISTIC(NumCapturesOther, "Number of arguments marked with captures "
76-
"attribute other than captures(none)");
75+
STATISTIC(NumCapturesPartial, "Number of arguments marked with captures "
76+
"attribute other than captures(none)");
7777
STATISTIC(NumReturned, "Number of arguments marked returned");
7878
STATISTIC(NumReadNoneArg, "Number of arguments marked readnone");
7979
STATISTIC(NumReadOnlyArg, "Number of arguments marked readonly");
@@ -114,7 +114,7 @@ static void addCapturesStat(CaptureInfo CI) {
114114
if (capturesNothing(CI))
115115
++NumCapturesNone;
116116
else
117-
++NumCapturesOther;
117+
++NumCapturesPartial;
118118
}
119119

120120
namespace {
@@ -549,18 +549,17 @@ struct ArgumentUsesTracker : public CaptureTracker {
549549

550550
void tooManyUses() override { CI = CaptureInfo::all(); }
551551

552-
std::optional<CaptureComponents> captured(const Use *U,
553-
CaptureInfo UseCI) override {
552+
Action captured(const Use *U, CaptureInfo UseCI) override {
554553
if (updateCaptureInfo(U, UseCI.getOtherComponents())) {
555554
// Don't bother continuing if we already capture everything.
556555
if (capturesAll(CI.getOtherComponents()))
557-
return stop();
558-
return continueDefault(UseCI);
556+
return Stop;
557+
return Continue;
559558
}
560559

561560
// For SCC argument tracking, we're not going to analyze other/ret
562561
// components separately, so don't follow the return value.
563-
return continueIgnoringReturn();
562+
return ContinueIgnoringReturn;
564563
}
565564

566565
bool updateCaptureInfo(const Use *U, CaptureComponents CC) {
@@ -1329,7 +1328,8 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
13291328
}
13301329

13311330
// Infer the access attributes given the new captures one
1332-
DetermineAccessAttrsForSingleton(A);
1331+
if (DetermineAccessAttrsForSingleton(A))
1332+
Changed.insert(A->getParent());
13331333
}
13341334
continue;
13351335
}
@@ -1369,7 +1369,7 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
13691369
}
13701370

13711371
// TODO(captures): Ignore address-only captures.
1372-
if (!capturesNothing(CC)) {
1372+
if (capturesAnything(CC)) {
13731373
// As the pointer may be captured, determine the pointer attributes
13741374
// looking at each argument invidivually.
13751375
for (ArgumentGraphNode *N : ArgumentSCC) {

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -882,8 +882,7 @@ bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
882882

883883
void tooManyUses() override { Captured = true; }
884884

885-
std::optional<CaptureComponents> captured(const Use *U,
886-
CaptureInfo CI) override {
885+
Action captured(const Use *U, CaptureInfo CI) override {
887886
// TODO(captures): Use CaptureInfo.
888887
auto *ICmp = dyn_cast<ICmpInst>(U->getUser());
889888
// We need to check that U is based *only* on the alloca, and doesn't
@@ -894,11 +893,11 @@ bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
894893
// Collect equality icmps of the alloca, and don't treat them as
895894
// captures.
896895
ICmps[ICmp] |= 1u << U->getOperandNo();
897-
return continueDefault(CI);
896+
return Continue;
898897
}
899898

900899
Captured = true;
901-
return stop();
900+
return Stop;
902901
}
903902
};
904903

0 commit comments

Comments
 (0)