Skip to content

Commit 1bcd2ff

Browse files
committed
[UniformAnalysis] Use Immediate postDom as last join
Given a divergent block, computeJoinPoints uses FloorIdx as an early-stopping criterion, but this is incorrect in some cases (shown in the two new lit tests). This change instead uses the immediate post-dominator of the divergent block as the last join to check for early stopping. To obtain the immediate post-dominator, it adds a post-dominator tree to GenericUniformityImpl.
1 parent 690a30f commit 1bcd2ff

File tree

8 files changed

+233
-55
lines changed

8 files changed

+233
-55
lines changed

llvm/include/llvm/ADT/GenericSSAContext.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ template <typename _FunctionT> class GenericSSAContext {
7777
// a given function.
7878
using DominatorTreeT = DominatorTreeBase<BlockT, false>;
7979

80+
// A post-dominator tree provides the post-dominance relation between
81+
// basic blocks in a given function.
82+
using PostDominatorTreeT = DominatorTreeBase<BlockT, true>;
83+
8084
GenericSSAContext() = default;
8185
GenericSSAContext(const FunctionT *F) : F(F) {}
8286

llvm/include/llvm/ADT/GenericUniformityImpl.h

Lines changed: 44 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
263263
public:
264264
using BlockT = typename ContextT::BlockT;
265265
using DominatorTreeT = typename ContextT::DominatorTreeT;
266+
using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
266267
using FunctionT = typename ContextT::FunctionT;
267268
using ValueRefT = typename ContextT::ValueRefT;
268269
using InstructionT = typename ContextT::InstructionT;
@@ -296,7 +297,9 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
296297
using DivergencePropagatorT = DivergencePropagator<ContextT>;
297298

298299
GenericSyncDependenceAnalysis(const ContextT &Context,
299-
const DominatorTreeT &DT, const CycleInfoT &CI);
300+
const DominatorTreeT &DT,
301+
const PostDominatorTreeT &PDT,
302+
const CycleInfoT &CI);
300303

301304
/// \brief Computes divergent join points and cycle exits caused by branch
302305
/// divergence in \p Term.
@@ -315,6 +318,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
315318
ModifiedPO CyclePO;
316319

317320
const DominatorTreeT &DT;
321+
const PostDominatorTreeT &PDT;
318322
const CycleInfoT &CI;
319323

320324
DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
@@ -336,6 +340,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
336340
using UseT = typename ContextT::UseT;
337341
using InstructionT = typename ContextT::InstructionT;
338342
using DominatorTreeT = typename ContextT::DominatorTreeT;
343+
using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
339344

340345
using CycleInfoT = GenericCycleInfo<ContextT>;
341346
using CycleT = typename CycleInfoT::CycleT;
@@ -348,10 +353,12 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
348353
using TemporalDivergenceTuple =
349354
std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
350355

351-
GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
356+
GenericUniformityAnalysisImpl(const DominatorTreeT &DT,
357+
const PostDominatorTreeT &PDT,
358+
const CycleInfoT &CI,
352359
const TargetTransformInfo *TTI)
353360
: Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
354-
TTI(TTI), DT(DT), SDA(Context, DT, CI) {}
361+
TTI(TTI), DT(DT), PDT(PDT), SDA(Context, DT, PDT, CI) {}
355362

356363
void initialize();
357364

@@ -435,6 +442,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
435442

436443
private:
437444
const DominatorTreeT &DT;
445+
const PostDominatorTreeT &PDT;
438446

439447
// Recognized cycles with divergent exits.
440448
SmallPtrSet<const CycleT *, 16> DivergentExitCycles;
@@ -493,6 +501,7 @@ template <typename ContextT> class DivergencePropagator {
493501
public:
494502
using BlockT = typename ContextT::BlockT;
495503
using DominatorTreeT = typename ContextT::DominatorTreeT;
504+
using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
496505
using FunctionT = typename ContextT::FunctionT;
497506
using ValueRefT = typename ContextT::ValueRefT;
498507

@@ -507,6 +516,7 @@ template <typename ContextT> class DivergencePropagator {
507516

508517
const ModifiedPO &CyclePOT;
509518
const DominatorTreeT &DT;
519+
const PostDominatorTreeT &PDT;
510520
const CycleInfoT &CI;
511521
const BlockT &DivTermBlock;
512522
const ContextT &Context;
@@ -522,10 +532,11 @@ template <typename ContextT> class DivergencePropagator {
522532
BlockLabelMapT &BlockLabels;
523533

524534
DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
525-
const CycleInfoT &CI, const BlockT &DivTermBlock)
526-
: CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
527-
Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
528-
BlockLabels(DivDesc->BlockLabels) {}
535+
const PostDominatorTreeT &PDT, const CycleInfoT &CI,
536+
const BlockT &DivTermBlock)
537+
: CyclePOT(CyclePOT), DT(DT), PDT(PDT), CI(CI),
538+
DivTermBlock(DivTermBlock), Context(CI.getSSAContext()),
539+
DivDesc(new DivergenceDescriptorT), BlockLabels(DivDesc->BlockLabels) {}
529540

530541
void printDefs(raw_ostream &Out) {
531542
Out << "Propagator::BlockLabels {\n";
@@ -542,6 +553,12 @@ template <typename ContextT> class DivergencePropagator {
542553
Out << "}\n";
543554
}
544555

556+
const BlockT *getIPDom(const BlockT *B) {
557+
const auto *Node = PDT.getNode(B);
558+
const auto *IPDomNode = Node->getIDom();
559+
return IPDomNode->getBlock();
560+
}
561+
545562
// Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
546563
// causes a divergent join.
547564
bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
@@ -610,10 +627,11 @@ template <typename ContextT> class DivergencePropagator {
610627
LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
611628
<< Context.print(&DivTermBlock) << "\n");
612629

613-
// Early stopping criterion
614-
int FloorIdx = CyclePOT.size() - 1;
615-
const BlockT *FloorLabel = nullptr;
616-
int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);
630+
// Immediate Post-dominator of DivTermBlock is the last join
631+
// to visit.
632+
const auto *ImmPDom = getIPDom(&DivTermBlock);
633+
634+
LLVM_DEBUG(dbgs() << "Last join: " << Context.print(ImmPDom) << "\n");
617635

618636
// Bootstrap with branch targets
619637
auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
@@ -626,34 +644,29 @@ template <typename ContextT> class DivergencePropagator {
626644
LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
627645
<< Context.print(SuccBlock) << "\n");
628646
}
629-
auto SuccIdx = CyclePOT.getIndex(SuccBlock);
630647
visitEdge(*SuccBlock, *SuccBlock);
631-
FloorIdx = std::min<int>(FloorIdx, SuccIdx);
632648
}
633649

634650
while (true) {
635651
auto BlockIdx = FreshLabels.find_last();
636-
if (BlockIdx == -1 || BlockIdx < FloorIdx)
652+
if (BlockIdx == -1)
637653
break;
638654

639655
LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));
640656

641657
FreshLabels.reset(BlockIdx);
642-
if (BlockIdx == DivTermIdx) {
643-
LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
658+
const auto *Block = CyclePOT[BlockIdx];
659+
if (Block == ImmPDom) {
660+
LLVM_DEBUG(dbgs() << "Skipping ImmPDom\n");
644661
continue;
645662
}
646663

647-
const auto *Block = CyclePOT[BlockIdx];
648664
LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
649665
<< BlockIdx << "\n");
650666

651667
const auto *Label = BlockLabels[Block];
652668
assert(Label);
653669

654-
bool CausedJoin = false;
655-
int LoweredFloorIdx = FloorIdx;
656-
657670
// If the current block is the header of a reducible cycle that
658671
// contains the divergent branch, then the label should be
659672
// propagated to the cycle exits. Such a header is the "last
@@ -681,28 +694,11 @@ template <typename ContextT> class DivergencePropagator {
681694
if (const auto *BlockCycle = getReducibleParent(Block)) {
682695
SmallVector<BlockT *, 4> BlockCycleExits;
683696
BlockCycle->getExitBlocks(BlockCycleExits);
684-
for (auto *BlockCycleExit : BlockCycleExits) {
685-
CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
686-
LoweredFloorIdx =
687-
std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
688-
}
697+
for (auto *BlockCycleExit : BlockCycleExits)
698+
visitCycleExitEdge(*BlockCycleExit, *Label);
689699
} else {
690-
for (const auto *SuccBlock : successors(Block)) {
691-
CausedJoin |= visitEdge(*SuccBlock, *Label);
692-
LoweredFloorIdx =
693-
std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
694-
}
695-
}
696-
697-
// Floor update
698-
if (CausedJoin) {
699-
// 1. Different labels pushed to successors
700-
FloorIdx = LoweredFloorIdx;
701-
} else if (FloorLabel != Label) {
702-
// 2. No join caused BUT we pushed a label that is different than the
703-
// last pushed label
704-
FloorIdx = LoweredFloorIdx;
705-
FloorLabel = Label;
700+
for (const auto *SuccBlock : successors(Block))
701+
visitEdge(*SuccBlock, *Label);
706702
}
707703
}
708704

@@ -742,8 +738,9 @@ typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
742738

743739
template <typename ContextT>
744740
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
745-
const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
746-
: CyclePO(Context), DT(DT), CI(CI) {
741+
const ContextT &Context, const DominatorTreeT &DT,
742+
const PostDominatorTreeT &PDT, const CycleInfoT &CI)
743+
: CyclePO(Context), DT(DT), PDT(PDT), CI(CI) {
747744
CyclePO.compute(CI);
748745
}
749746

@@ -761,7 +758,7 @@ auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
761758
return *ItCached->second;
762759

763760
// compute all join points
764-
DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
761+
DivergencePropagatorT Propagator(CyclePO, DT, PDT, CI, *DivTermBlock);
765762
auto DivDesc = Propagator.computeJoinPoints();
766763

767764
auto printBlockSet = [&](ConstBlockSet &Blocks) {
@@ -1155,9 +1152,9 @@ bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
11551152

11561153
template <typename ContextT>
11571154
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
1158-
const DominatorTreeT &DT, const CycleInfoT &CI,
1159-
const TargetTransformInfo *TTI) {
1160-
DA.reset(new ImplT{DT, CI, TTI});
1155+
const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
1156+
const CycleInfoT &CI, const TargetTransformInfo *TTI) {
1157+
DA.reset(new ImplT{DT, PDT, CI, TTI});
11611158
}
11621159

11631160
template <typename ContextT>

llvm/include/llvm/ADT/GenericUniformityInfo.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ template <typename ContextT> class GenericUniformityInfo {
3535
using UseT = typename ContextT::UseT;
3636
using InstructionT = typename ContextT::InstructionT;
3737
using DominatorTreeT = typename ContextT::DominatorTreeT;
38+
using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
3839
using ThisT = GenericUniformityInfo<ContextT>;
3940

4041
using CycleInfoT = GenericCycleInfo<ContextT>;
@@ -43,7 +44,8 @@ template <typename ContextT> class GenericUniformityInfo {
4344
using TemporalDivergenceTuple =
4445
std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
4546

46-
GenericUniformityInfo(const DominatorTreeT &DT, const CycleInfoT &CI,
47+
GenericUniformityInfo(const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
48+
const CycleInfoT &CI,
4749
const TargetTransformInfo *TTI = nullptr);
4850
GenericUniformityInfo() = default;
4951
GenericUniformityInfo(GenericUniformityInfo &&) = default;

llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/CodeGen/MachineCycleAnalysis.h"
1919
#include "llvm/CodeGen/MachineDominators.h"
2020
#include "llvm/CodeGen/MachinePassManager.h"
21+
#include "llvm/CodeGen/MachinePostDominators.h"
2122
#include "llvm/CodeGen/MachineSSAContext.h"
2223

2324
namespace llvm {
@@ -31,7 +32,8 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
3132
/// everything is uniform.
3233
MachineUniformityInfo computeMachineUniformityInfo(
3334
MachineFunction &F, const MachineCycleInfo &cycleInfo,
34-
const MachineDominatorTree &domTree, bool HasBranchDivergence);
35+
const MachineDominatorTree &domTree,
36+
const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence);
3537

3638
/// Legacy analysis pass which computes a \ref MachineUniformityInfo.
3739
class MachineUniformityAnalysisPass : public MachineFunctionPass {

llvm/lib/Analysis/UniformityAnalysis.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "llvm/Analysis/UniformityAnalysis.h"
1010
#include "llvm/ADT/GenericUniformityImpl.h"
1111
#include "llvm/Analysis/CycleAnalysis.h"
12+
#include "llvm/Analysis/PostDominators.h"
1213
#include "llvm/Analysis/TargetTransformInfo.h"
1314
#include "llvm/IR/Dominators.h"
1415
#include "llvm/IR/InstIterator.h"
@@ -114,9 +115,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
114115
llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
115116
FunctionAnalysisManager &FAM) {
116117
auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
118+
auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F);
117119
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
118120
auto &CI = FAM.getResult<CycleAnalysis>(F);
119-
UniformityInfo UI{DT, CI, &TTI};
121+
UniformityInfo UI{DT, PDT, CI, &TTI};
120122
// Skip computation if we can assume everything is uniform.
121123
if (TTI.hasBranchDivergence(&F))
122124
UI.compute();
@@ -148,6 +150,7 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
148150
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
149151
"Uniformity Analysis", true, true)
150152
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
153+
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
151154
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
152155
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
153156
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
@@ -156,18 +159,21 @@ INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
156159
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
157160
AU.setPreservesAll();
158161
AU.addRequired<DominatorTreeWrapperPass>();
162+
AU.addRequired<PostDominatorTreeWrapperPass>();
159163
AU.addRequiredTransitive<CycleInfoWrapperPass>();
160164
AU.addRequired<TargetTransformInfoWrapperPass>();
161165
}
162166

163167
bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
164168
auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
165169
auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
170+
auto &pdomTree = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
166171
auto &targetTransformInfo =
167172
getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
168173

169174
m_function = &F;
170-
m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
175+
m_uniformityInfo =
176+
UniformityInfo{domTree, pdomTree, cycleInfo, &targetTransformInfo};
171177

172178
// Skip computation if we can assume everything is uniform.
173179
if (targetTransformInfo.hasBranchDivergence(m_function))

llvm/lib/CodeGen/MachineUniformityAnalysis.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "llvm/Analysis/TargetTransformInfo.h"
1212
#include "llvm/CodeGen/MachineCycleAnalysis.h"
1313
#include "llvm/CodeGen/MachineDominators.h"
14+
#include "llvm/CodeGen/MachinePostDominators.h"
1415
#include "llvm/CodeGen/MachineRegisterInfo.h"
1516
#include "llvm/CodeGen/MachineSSAContext.h"
1617
#include "llvm/CodeGen/TargetInstrInfo.h"
@@ -156,9 +157,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
156157

157158
MachineUniformityInfo llvm::computeMachineUniformityInfo(
158159
MachineFunction &F, const MachineCycleInfo &cycleInfo,
159-
const MachineDominatorTree &domTree, bool HasBranchDivergence) {
160+
const MachineDominatorTree &domTree,
161+
const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence) {
160162
assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
161-
MachineUniformityInfo UI(domTree, cycleInfo);
163+
MachineUniformityInfo UI(domTree, pdomTree, cycleInfo);
162164
if (HasBranchDivergence)
163165
UI.compute();
164166
return UI;
@@ -184,12 +186,13 @@ MachineUniformityAnalysis::Result
184186
MachineUniformityAnalysis::run(MachineFunction &MF,
185187
MachineFunctionAnalysisManager &MFAM) {
186188
auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
189+
auto &PDomTree = MFAM.getResult<MachinePostDominatorTreeAnalysis>(MF);
187190
auto &CI = MFAM.getResult<MachineCycleAnalysis>(MF);
188191
auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
189192
.getManager();
190193
auto &F = MF.getFunction();
191194
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
192-
return computeMachineUniformityInfo(MF, CI, DomTree,
195+
return computeMachineUniformityInfo(MF, CI, DomTree, PDomTree,
193196
TTI.hasBranchDivergence(&F));
194197
}
195198

@@ -215,22 +218,26 @@ INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
215218
"Machine Uniformity Info Analysis", false, true)
216219
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
217220
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
221+
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
218222
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
219223
"Machine Uniformity Info Analysis", false, true)
220224

221225
void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
222226
AU.setPreservesAll();
223227
AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
224228
AU.addRequired<MachineDominatorTreeWrapperPass>();
229+
AU.addRequired<MachinePostDominatorTreeWrapperPass>();
225230
MachineFunctionPass::getAnalysisUsage(AU);
226231
}
227232

228233
bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
229234
auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
235+
auto &PDomTree =
236+
getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
230237
auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
231238
// FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
232239
// default NoTTI
233-
UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
240+
UI = computeMachineUniformityInfo(MF, CI, DomTree, PDomTree, true);
234241
return false;
235242
}
236243

0 commit comments

Comments
 (0)