Skip to content

Commit d50151b

Browse files
committed
[LAA] Strip findForkedPointer (NFC)
Remove a level of indirection due to findForkedPointer, in an effort to improve code.
1 parent 96d5567 commit d50151b

File tree

1 file changed

+32
-39
lines changed

1 file changed

+32
-39
lines changed

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,48 +1231,44 @@ static void findForkedSCEVs(
12311231
}
12321232
}
12331233

1234-
static SmallVector<PointerIntPair<const SCEV *, 1, bool>>
1235-
findForkedPointer(PredicatedScalarEvolution &PSE,
1236-
const DenseMap<Value *, const SCEV *> &StridesMap, Value *Ptr,
1237-
const Loop *L) {
1238-
ScalarEvolution *SE = PSE.getSE();
1239-
assert(SE->isSCEVable(Ptr->getType()) && "Value is not SCEVable!");
1240-
SmallVector<PointerIntPair<const SCEV *, 1, bool>> Scevs;
1241-
findForkedSCEVs(SE, L, Ptr, Scevs, MaxForkedSCEVDepth);
1242-
1243-
// For now, we will only accept a forked pointer with two possible SCEVs
1244-
// that are either SCEVAddRecExprs or loop invariant.
1245-
if (Scevs.size() == 2 &&
1246-
(isa<SCEVAddRecExpr>(get<0>(Scevs[0])) ||
1247-
SE->isLoopInvariant(get<0>(Scevs[0]), L)) &&
1248-
(isa<SCEVAddRecExpr>(get<0>(Scevs[1])) ||
1249-
SE->isLoopInvariant(get<0>(Scevs[1]), L))) {
1250-
LLVM_DEBUG(dbgs() << "LAA: Found forked pointer: " << *Ptr << "\n");
1251-
LLVM_DEBUG(dbgs() << "\t(1) " << *get<0>(Scevs[0]) << "\n");
1252-
LLVM_DEBUG(dbgs() << "\t(2) " << *get<0>(Scevs[1]) << "\n");
1253-
return Scevs;
1254-
}
1255-
1256-
return {{replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr), false}};
1257-
}
1258-
12591234
bool AccessAnalysis::createCheckForAccess(
12601235
RuntimePointerChecking &RtCheck, MemAccessInfo Access, Type *AccessTy,
12611236
const DenseMap<Value *, const SCEV *> &StridesMap,
12621237
DenseMap<Value *, unsigned> &DepSetId, Loop *TheLoop,
12631238
unsigned &RunningDepId, unsigned ASId, bool Assume) {
12641239
Value *Ptr = Access.getPointer();
1240+
ScalarEvolution *SE = PSE.getSE();
1241+
assert(SE->isSCEVable(Ptr->getType()) && "Value is not SCEVable!");
1242+
1243+
SmallVector<PointerIntPair<const SCEV *, 1, bool>> RTCheckPtrs;
1244+
findForkedSCEVs(SE, TheLoop, Ptr, RTCheckPtrs, MaxForkedSCEVDepth);
1245+
assert(!RTCheckPtrs.empty() &&
1246+
"Must have some runtime-check pointer candidates");
12651247

1266-
SmallVector<PointerIntPair<const SCEV *, 1, bool>> TranslatedPtrs =
1267-
findForkedPointer(PSE, StridesMap, Ptr, TheLoop);
1268-
assert(!TranslatedPtrs.empty() && "must have some translated pointers");
1248+
auto IsLoopInvariantOrAR =
1249+
[&SE, &TheLoop](const PointerIntPair<const SCEV *, 1, bool> &P) {
1250+
return SE->isLoopInvariant(get<0>(P), TheLoop) ||
1251+
isa<SCEVAddRecExpr>(get<0>(P));
1252+
};
1253+
1254+
// RTCheckPtrs must have size 2 if there are forked pointers. Otherwise, there
1255+
// are no forked pointers; replaceSymbolicStridesSCEV in this case.
1256+
if (RTCheckPtrs.size() == 2 && all_of(RTCheckPtrs, IsLoopInvariantOrAR)) {
1257+
LLVM_DEBUG(dbgs() << "LAA: Found forked pointer: " << *Ptr << "\n");
1258+
for (auto [Idx, Q] : enumerate(RTCheckPtrs))
1259+
LLVM_DEBUG(dbgs() << "\t(" << Idx << ") " << *Q.getPointer() << "\n");
1260+
} else {
1261+
RTCheckPtrs.truncate(1);
1262+
RTCheckPtrs.front().setPointer(
1263+
replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr));
1264+
RTCheckPtrs.front().setInt(false);
1265+
}
12691266

12701267
/// Check whether all pointers can participate in a runtime bounds check. They
1271-
/// must either be invariant or AddRecs. If ShouldCheckWrap is true, they also
1272-
/// must not wrap.
1273-
for (auto &P : TranslatedPtrs) {
1268+
/// must either be invariant or non-wrapping affine AddRecs.
1269+
for (auto &P : RTCheckPtrs) {
12741270
// The bounds for loop-invariant pointer is trivial.
1275-
if (PSE.getSE()->isLoopInvariant(P.getPointer(), TheLoop))
1271+
if (SE->isLoopInvariant(P.getPointer(), TheLoop))
12761272
continue;
12771273

12781274
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(P.getPointer());
@@ -1283,21 +1279,18 @@ bool AccessAnalysis::createCheckForAccess(
12831279

12841280
// If there's only one option for Ptr, look it up after bounds and wrap
12851281
// checking, because assumptions might have been added to PSE.
1286-
if (TranslatedPtrs.size() == 1) {
1282+
if (RTCheckPtrs.size() == 1) {
12871283
AR =
12881284
cast<SCEVAddRecExpr>(replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr));
12891285
P.setPointer(AR);
12901286
}
12911287

1292-
// When we run after a failing dependency check we have to make sure
1293-
// we don't have wrapping pointers.
1294-
if (!isNoWrap(PSE, AR, TranslatedPtrs.size() == 1 ? Ptr : nullptr, AccessTy,
1295-
TheLoop, Assume)) {
1288+
if (!isNoWrap(PSE, AR, RTCheckPtrs.size() == 1 ? Ptr : nullptr, AccessTy,
1289+
TheLoop, Assume))
12961290
return false;
1297-
}
12981291
}
12991292

1300-
for (auto [PtrExpr, NeedsFreeze] : TranslatedPtrs) {
1293+
for (const auto &[PtrExpr, NeedsFreeze] : RTCheckPtrs) {
13011294
// The id of the dependence set.
13021295
unsigned DepId;
13031296

0 commit comments

Comments
 (0)