Skip to content

Commit e86d562

Browse files
committed
[LAA] Detect hoistable uniform load/store IV pattern
Introduces isInvariantLoadHoistable that identifies when a loop-invariant load can be safely hoisted above the loop using MemorySSA and ScalarEvolution analysis. It assumes no aliasing or a alias check. * Confirms that the load address is loop-invariant * Searches the loop header for exactly one must-alias load and one must-alias store to the same memory location. * Ensures that both are non-volatile and that the MemorySSA clobber chain between them does not contain conflicting must-alias definitions. * Verifies that the store’s value has an invariant SCEV step relative to the load. * Ensures that instructions that represent the SCEV doesn't external users outside this slice. Will be used to enable selective invariant load hoisting in LoopVectorize where LAA can precisely prove safety.
1 parent 74f06ab commit e86d562

File tree

11 files changed

+314
-33
lines changed

11 files changed

+314
-33
lines changed

llvm/include/llvm/Analysis/LoopAccessAnalysis.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class DataLayout;
2828
class Loop;
2929
class raw_ostream;
3030
class TargetTransformInfo;
31+
class MemorySSA;
3132

3233
/// Collection of parameters shared beetween the Loop Vectorizer and the
3334
/// Loop Access Analysis.
@@ -181,11 +182,12 @@ class MemoryDepChecker {
181182
};
182183

183184
MemoryDepChecker(PredicatedScalarEvolution &PSE, AssumptionCache *AC,
184-
DominatorTree *DT, const Loop *L,
185+
MemorySSA *MSSA, DominatorTree *DT, AAResults *AA,
186+
const Loop *L,
185187
const DenseMap<Value *, const SCEV *> &SymbolicStrides,
186188
unsigned MaxTargetVectorWidthInBits,
187189
std::optional<ScalarEvolution::LoopGuards> &LoopGuards)
188-
: PSE(PSE), AC(AC), DT(DT), InnermostLoop(L),
190+
: PSE(PSE), AC(AC), DT(DT), MSSA(MSSA), AA(AA), InnermostLoop(L),
189191
SymbolicStrides(SymbolicStrides),
190192
MaxTargetVectorWidthInBits(MaxTargetVectorWidthInBits),
191193
LoopGuards(LoopGuards) {}
@@ -292,6 +294,14 @@ class MemoryDepChecker {
292294
return PointerBounds;
293295
}
294296

297+
/// Return if a Load can be hoisted in this loop with a pattern of a
298+
/// memory induction variable. This assumes a alias runtime check
299+
/// will be used before hoisting.
300+
bool
301+
isInvariantLoadHoistable(LoadInst *L, ScalarEvolution &SE, StoreInst **S,
302+
const SCEV **Step,
303+
SmallVectorImpl<Instruction *> *Instructions) const;
304+
295305
DominatorTree *getDT() const {
296306
assert(DT && "requested DT, but it is not available");
297307
return DT;
@@ -312,6 +322,8 @@ class MemoryDepChecker {
312322

313323
AssumptionCache *AC;
314324
DominatorTree *DT;
325+
MemorySSA *MSSA;
326+
AAResults *AA;
315327

316328
const Loop *InnermostLoop;
317329

@@ -692,7 +704,7 @@ class LoopAccessInfo {
692704
const TargetTransformInfo *TTI,
693705
const TargetLibraryInfo *TLI, AAResults *AA,
694706
DominatorTree *DT, LoopInfo *LI, AssumptionCache *AC,
695-
bool AllowPartial = false);
707+
MemorySSA *MSSA, bool AllowPartial = false);
696708

697709
/// Return true we can analyze the memory accesses in the loop and there are
698710
/// no memory dependence cycles. Note that for dependences between loads &
@@ -786,7 +798,8 @@ class LoopAccessInfo {
786798
/// Analyze the loop. Returns true if all memory access in the loop can be
787799
/// vectorized.
788800
bool analyzeLoop(AAResults *AA, const LoopInfo *LI,
789-
const TargetLibraryInfo *TLI, DominatorTree *DT);
801+
const TargetLibraryInfo *TLI, DominatorTree *DT,
802+
MemorySSA *MSSA);
790803

791804
/// Check if the structure of the loop allows it to be analyzed by this
792805
/// pass.
@@ -963,12 +976,15 @@ class LoopAccessInfoManager {
963976
TargetTransformInfo *TTI;
964977
const TargetLibraryInfo *TLI = nullptr;
965978
AssumptionCache *AC;
979+
MemorySSA *MSSA;
966980

967981
public:
968982
LoopAccessInfoManager(ScalarEvolution &SE, AAResults &AA, DominatorTree &DT,
969983
LoopInfo &LI, TargetTransformInfo *TTI,
970-
const TargetLibraryInfo *TLI, AssumptionCache *AC)
971-
: SE(SE), AA(AA), DT(DT), LI(LI), TTI(TTI), TLI(TLI), AC(AC) {}
984+
const TargetLibraryInfo *TLI, AssumptionCache *AC,
985+
MemorySSA *MSSA)
986+
: SE(SE), AA(AA), DT(DT), LI(LI), TTI(TTI), TLI(TLI), AC(AC), MSSA(MSSA) {
987+
}
972988

973989
LLVM_ABI const LoopAccessInfo &getInfo(Loop &L, bool AllowPartial = false);
974990

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 237 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/Analysis/LoopInfo.h"
3030
#include "llvm/Analysis/LoopIterator.h"
3131
#include "llvm/Analysis/MemoryLocation.h"
32+
#include "llvm/Analysis/MemorySSA.h"
3233
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
3334
#include "llvm/Analysis/ScalarEvolution.h"
3435
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
@@ -1777,6 +1778,232 @@ bool llvm::isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL,
17771778
return Diff == 1;
17781779
}
17791780

1781+
/// Collects all subexpressions that appear within a given SCEV tree.
1782+
struct SCEVSubexprCollector : public SCEVVisitor<SCEVSubexprCollector, void> {
1783+
SmallPtrSet<const SCEV *, 4> &Subs;
1784+
SCEVSubexprCollector(SmallPtrSet<const SCEV *, 4> &S) : Subs(S) {}
1785+
1786+
template <typename Operands> void visitOperands(Operands operands) {
1787+
for (auto *Op : operands)
1788+
visit(Op);
1789+
}
1790+
void visitConstant(const SCEVConstant *C) { Subs.insert(C); }
1791+
void visitUnknown(const SCEVUnknown *U) { Subs.insert(U); }
1792+
void visitAddExpr(const SCEVAddExpr *E) {
1793+
Subs.insert(E);
1794+
for (auto *Op : E->operands())
1795+
visit(Op);
1796+
}
1797+
void visitMulExpr(const SCEVMulExpr *E) {
1798+
Subs.insert(E);
1799+
for (auto *Op : E->operands())
1800+
visit(Op);
1801+
}
1802+
void visitAddRecExpr(const SCEVAddRecExpr *E) {
1803+
Subs.insert(E);
1804+
for (auto *Op : E->operands())
1805+
visit(Op);
1806+
}
1807+
void visitSMaxExpr(const SCEVSMaxExpr *E) {
1808+
Subs.insert(E);
1809+
for (auto *Op : E->operands())
1810+
visit(Op);
1811+
}
1812+
void visitSMinExpr(const SCEVSMinExpr *E) {
1813+
Subs.insert(E);
1814+
for (auto *Op : E->operands())
1815+
visit(Op);
1816+
}
1817+
void visitUMinExpr(const SCEVUMinExpr *E) {
1818+
Subs.insert(E);
1819+
for (auto *Op : E->operands())
1820+
visit(Op);
1821+
}
1822+
void visitUMaxExpr(const SCEVUMaxExpr *E) {
1823+
Subs.insert(E);
1824+
for (auto *Op : E->operands())
1825+
visit(Op);
1826+
}
1827+
void visitMinMaxExpr(const SCEVMinMaxExpr *E) {
1828+
Subs.insert(E);
1829+
for (auto *Op : E->operands())
1830+
visit(Op);
1831+
}
1832+
void visitUDivExpr(const SCEVUDivExpr *E) {
1833+
Subs.insert(E);
1834+
visit(E->getLHS());
1835+
visit(E->getRHS());
1836+
}
1837+
void visitZeroExtendExpr(const SCEVZeroExtendExpr *E) {
1838+
Subs.insert(E);
1839+
visit(E->getOperand());
1840+
}
1841+
void visitSignExtendExpr(const SCEVSignExtendExpr *E) {
1842+
Subs.insert(E);
1843+
visit(E->getOperand());
1844+
}
1845+
void visitTruncateExpr(const SCEVTruncateExpr *E) {
1846+
Subs.insert(E);
1847+
visit(E->getOperand());
1848+
}
1849+
void visitCouldNotCompute(const SCEVCouldNotCompute *E) { Subs.insert(E); }
1850+
void visitVScale(const SCEVVScale *E) {
1851+
Subs.insert(E);
1852+
visitOperands(E->operands());
1853+
}
1854+
void visitPtrToIntExpr(const SCEVPtrToIntExpr *E) {
1855+
Subs.insert(E);
1856+
visitOperands(E->operands());
1857+
}
1858+
void visitSequentialUMinExpr(const SCEVSequentialUMinExpr *E) {
1859+
Subs.insert(E);
1860+
visitOperands(E->operands());
1861+
}
1862+
};
1863+
1864+
bool MemoryDepChecker::isInvariantLoadHoistable(
1865+
LoadInst *L, ScalarEvolution &SE, StoreInst **S, const SCEV **StepSCEV,
1866+
SmallVectorImpl<Instruction *> *Instructions) const {
1867+
assert(L != nullptr);
1868+
assert(InnermostLoop->isLoopInvariant(L->getPointerOperand()));
1869+
1870+
if (!MSSA)
1871+
return false;
1872+
1873+
MemoryAccess *MA = MSSA->getMemoryAccess(L);
1874+
auto QLoc = MemoryLocation::get(L);
1875+
1876+
SmallVector<StoreInst *> Stores;
1877+
SmallVector<LoadInst *> Loads;
1878+
1879+
for (auto &&I : *InnermostLoop->getHeader()) {
1880+
if (auto *Store = dyn_cast<StoreInst>(&I)) {
1881+
AliasResult AR = AA->alias(MemoryLocation::get(Store), QLoc);
1882+
if (AR == AliasResult::MustAlias)
1883+
Stores.push_back(Store);
1884+
}
1885+
if (auto *Load = dyn_cast<LoadInst>(&I)) {
1886+
AliasResult AR = AA->alias(MemoryLocation::get(Load), QLoc);
1887+
if (AR == AliasResult::MustAlias)
1888+
Loads.push_back(Load);
1889+
}
1890+
}
1891+
1892+
if (Loads.size() != 1 || Loads[0]->isVolatile() || Stores.size() != 1 ||
1893+
Stores[0]->isVolatile())
1894+
return false;
1895+
1896+
// I have the memory PHI, so I know where is the backedge
1897+
// I have to find all memory accesses to the same cell (that I care)
1898+
// There should be a single memory use and a single memorydef
1899+
// memory use should have MemoryPhi as transitive clobber
1900+
// backedge should have the MemoryDef as a transitive clobber (must-alias) (?)
1901+
MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(MA);
1902+
while (auto *MD = dyn_cast<MemoryUseOrDef>(Clobber)) {
1903+
Instruction *DefI = MD->getMemoryInst();
1904+
1905+
if (!DefI)
1906+
return false;
1907+
1908+
AliasResult AR = AA->alias(MemoryLocation::get(DefI), QLoc);
1909+
1910+
Clobber = MD->getDefiningAccess();
1911+
1912+
// We assume runtime aliasing check will be used
1913+
if (AR == AliasResult::MustAlias)
1914+
return false;
1915+
}
1916+
1917+
MemoryAccess *MS = MSSA->getMemoryAccess(Stores[0]);
1918+
MemoryAccess *StoreClobber = MSSA->getWalker()->getClobberingMemoryAccess(MS);
1919+
while (true) {
1920+
if (isa<MemoryPhi>(StoreClobber))
1921+
break;
1922+
if (auto *MD = dyn_cast<MemoryUseOrDef>(StoreClobber)) {
1923+
Instruction *DefI = MD->getMemoryInst();
1924+
1925+
if (!DefI)
1926+
return false;
1927+
1928+
AliasResult AR = AA->alias(MemoryLocation::get(DefI), QLoc);
1929+
1930+
StoreClobber = MD->getDefiningAccess();
1931+
1932+
if (AR == AliasResult::MustAlias)
1933+
return false;
1934+
}
1935+
}
1936+
1937+
if (!SE.isSCEVable(Stores[0]->getValueOperand()->getType()))
1938+
return false;
1939+
1940+
const SCEV *LoadSCEV = SE.getUnknown(L);
1941+
const SCEV *StoreSCEV = SE.getSCEV(Stores[0]->getValueOperand());
1942+
1943+
auto Step = SE.getMinusSCEV(StoreSCEV, LoadSCEV);
1944+
1945+
if (isa<SCEVCouldNotCompute>(Step) ||
1946+
!SE.isLoopInvariant(Step, InnermostLoop))
1947+
return false;
1948+
1949+
SmallVector<Instruction *, 4> WL;
1950+
1951+
SmallPtrSet<Instruction *, 4> Slice;
1952+
SmallPtrSet<const SCEV *, 4> Subs;
1953+
SCEVSubexprCollector Collector(Subs);
1954+
Collector.visit(StoreSCEV);
1955+
1956+
// Register all instructions that matches the SCEV
1957+
// to allow its removal when hoisting it and
1958+
// re-expanding the SCEV
1959+
auto enqueueIfMatches = [&](Value *X) {
1960+
if (auto *XI = dyn_cast<Instruction>(X)) {
1961+
const SCEV *SX = SE.getSCEV(XI);
1962+
if (Subs.contains(SX) && Slice.insert(XI).second)
1963+
WL.push_back(XI);
1964+
}
1965+
};
1966+
1967+
enqueueIfMatches(Stores[0]->getValueOperand());
1968+
1969+
while (!WL.empty()) {
1970+
Instruction *I = WL.pop_back_val();
1971+
1972+
for (Value *Op : I->operands()) {
1973+
if (isa<Constant>(Op) || isa<Argument>(Op))
1974+
continue;
1975+
enqueueIfMatches(Op);
1976+
}
1977+
}
1978+
1979+
auto hasExternalUsers =
1980+
[&Stores](const SmallPtrSetImpl<Instruction *> &Slice) {
1981+
for (Instruction *I : Slice)
1982+
for (Use &U : I->uses())
1983+
if (auto *UserI = dyn_cast<Instruction>(U.getUser())) {
1984+
if (isa<DbgInfoIntrinsic>(UserI))
1985+
continue;
1986+
if (!Slice.count(UserI) &&
1987+
!std::count(Stores.begin(), Stores.end(), UserI))
1988+
return true;
1989+
}
1990+
return false;
1991+
};
1992+
1993+
if (hasExternalUsers(Slice))
1994+
return false;
1995+
1996+
if (S)
1997+
*S = Stores[0];
1998+
if (StepSCEV)
1999+
*StepSCEV = Step;
2000+
2001+
if (Instructions)
2002+
Instructions->insert(Instructions->end(), Slice.begin(), Slice.end());
2003+
2004+
return true;
2005+
}
2006+
17802007
void MemoryDepChecker::addAccess(StoreInst *SI) {
17812008
visitPointers(SI->getPointerOperand(), *InnermostLoop,
17822009
[this, SI](Value *Ptr) {
@@ -2505,7 +2732,7 @@ bool LoopAccessInfo::canAnalyzeLoop() {
25052732

25062733
bool LoopAccessInfo::analyzeLoop(AAResults *AA, const LoopInfo *LI,
25072734
const TargetLibraryInfo *TLI,
2508-
DominatorTree *DT) {
2735+
DominatorTree *DT, MemorySSA *MSSA) {
25092736
// Holds the Load and Store instructions.
25102737
SmallVector<LoadInst *, 16> Loads;
25112738
SmallVector<StoreInst *, 16> Stores;
@@ -3064,7 +3291,8 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
30643291
const TargetTransformInfo *TTI,
30653292
const TargetLibraryInfo *TLI, AAResults *AA,
30663293
DominatorTree *DT, LoopInfo *LI,
3067-
AssumptionCache *AC, bool AllowPartial)
3294+
AssumptionCache *AC, MemorySSA *MSSA,
3295+
bool AllowPartial)
30683296
: PSE(std::make_unique<PredicatedScalarEvolution>(*SE, *L)),
30693297
PtrRtChecking(nullptr), TheLoop(L), AllowPartial(AllowPartial) {
30703298
unsigned MaxTargetVectorWidthInBits = std::numeric_limits<unsigned>::max();
@@ -3075,11 +3303,12 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
30753303
TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) * 2;
30763304

30773305
DepChecker = std::make_unique<MemoryDepChecker>(
3078-
*PSE, AC, DT, L, SymbolicStrides, MaxTargetVectorWidthInBits, LoopGuards);
3306+
*PSE, AC, MSSA, DT, AA, L, SymbolicStrides, MaxTargetVectorWidthInBits,
3307+
LoopGuards);
30793308
PtrRtChecking =
30803309
std::make_unique<RuntimePointerChecking>(*DepChecker, SE, LoopGuards);
30813310
if (canAnalyzeLoop())
3082-
CanVecMem = analyzeLoop(AA, LI, TLI, DT);
3311+
CanVecMem = analyzeLoop(AA, LI, TLI, DT, MSSA);
30833312
}
30843313

30853314
void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const {
@@ -3145,7 +3374,7 @@ const LoopAccessInfo &LoopAccessInfoManager::getInfo(Loop &L,
31453374
// or if it was created with a different value of AllowPartial.
31463375
if (Inserted || It->second->hasAllowPartial() != AllowPartial)
31473376
It->second = std::make_unique<LoopAccessInfo>(&L, &SE, TTI, TLI, &AA, &DT,
3148-
&LI, AC, AllowPartial);
3377+
&LI, AC, MSSA, AllowPartial);
31493378

31503379
return *It->second;
31513380
}
@@ -3189,7 +3418,9 @@ LoopAccessInfoManager LoopAccessAnalysis::run(Function &F,
31893418
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
31903419
auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
31913420
auto &AC = FAM.getResult<AssumptionAnalysis>(F);
3192-
return LoopAccessInfoManager(SE, AA, DT, LI, &TTI, &TLI, &AC);
3421+
auto &MSSA = FAM.getResult<MemorySSAAnalysis>(F);
3422+
return LoopAccessInfoManager(SE, AA, DT, LI, &TTI, &TLI, &AC,
3423+
&MSSA.getMSSA());
31933424
}
31943425

31953426
AnalysisKey LoopAccessAnalysis::Key;

llvm/lib/Transforms/Scalar/LoopFlatten.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,7 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
10101010
// this pass will simplify all loops that contain inner loops,
10111011
// regardless of whether anything ends up being flattened.
10121012
LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr,
1013-
&AR.AC);
1013+
&AR.AC, AR.MSSA);
10141014
for (Loop *InnerLoop : LN.getLoops()) {
10151015
auto *OuterLoop = InnerLoop->getParentLoop();
10161016
if (!OuterLoop)

llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,8 @@ PreservedAnalyses LoopVersioningLICMPass::run(Loop &L, LoopAnalysisManager &AM,
549549
const Function *F = L.getHeader()->getParent();
550550
OptimizationRemarkEmitter ORE(F);
551551

552-
LoopAccessInfoManager LAIs(*SE, *AA, *DT, LAR.LI, nullptr, nullptr, &LAR.AC);
552+
LoopAccessInfoManager LAIs(*SE, *AA, *DT, LAR.LI, nullptr, nullptr, &LAR.AC,
553+
LAR.MSSA);
553554
if (!LoopVersioningLICM(AA, SE, &ORE, LAIs, LAR.LI, &L).run(DT))
554555
return PreservedAnalyses::all();
555556
return getLoopPassPreservedAnalyses();

0 commit comments

Comments
 (0)