From 1bcd2ffabb7115d27ba87be4992e5bbcc8f86fb9 Mon Sep 17 00:00:00 2001 From: "Gu, Junjie" Date: Wed, 14 May 2025 23:06:36 -0700 Subject: [PATCH] [UniformAnalysis] Use Immediate postDom as last join Given a divergent block, computeJoinPoints uses FloorIdx to do early stopping. But it is incorrect for some cases (shown in the two new lit tests). This change uses the immediate post-dominator as the last join to check for early stopping. It adds post-dominator to genericUniformityImpl in order to get immediate postDom. --- llvm/include/llvm/ADT/GenericSSAContext.h | 4 + llvm/include/llvm/ADT/GenericUniformityImpl.h | 91 +++++++++---------- llvm/include/llvm/ADT/GenericUniformityInfo.h | 4 +- .../llvm/CodeGen/MachineUniformityAnalysis.h | 4 +- llvm/lib/Analysis/UniformityAnalysis.cpp | 10 +- .../lib/CodeGen/MachineUniformityAnalysis.cpp | 15 ++- .../AMDGPU/phi_div_branch.ll | 78 ++++++++++++++++ .../UniformityAnalysis/AMDGPU/phi_div_loop.ll | 82 +++++++++++++++++ 8 files changed, 233 insertions(+), 55 deletions(-) create mode 100644 llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll create mode 100644 llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_loop.ll diff --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h index 6aa3a8b9b6e0b..e99d4b1c6dd45 100644 --- a/llvm/include/llvm/ADT/GenericSSAContext.h +++ b/llvm/include/llvm/ADT/GenericSSAContext.h @@ -77,6 +77,10 @@ template class GenericSSAContext { // a given funciton. using DominatorTreeT = DominatorTreeBase; + // A post-dominator tree provides the post-dominance relation between + // basic blocks in a given funciton. + using PostDominatorTreeT = DominatorTreeBase; + GenericSSAContext() = default; GenericSSAContext(const FunctionT *F) : F(F) {} diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h index d10355fff1bea..f404577bb7e56 100644 --- a/llvm/include/llvm/ADT/GenericUniformityImpl.h +++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h @@ -263,6 +263,7 @@ template class GenericSyncDependenceAnalysis { public: using BlockT = typename ContextT::BlockT; using DominatorTreeT = typename ContextT::DominatorTreeT; + using PostDominatorTreeT = typename ContextT::PostDominatorTreeT; using FunctionT = typename ContextT::FunctionT; using ValueRefT = typename ContextT::ValueRefT; using InstructionT = typename ContextT::InstructionT; @@ -296,7 +297,9 @@ template class GenericSyncDependenceAnalysis { using DivergencePropagatorT = DivergencePropagator; GenericSyncDependenceAnalysis(const ContextT &Context, - const DominatorTreeT &DT, const CycleInfoT &CI); + const DominatorTreeT &DT, + const PostDominatorTreeT &PDT, + const CycleInfoT &CI); /// \brief Computes divergent join points and cycle exits caused by branch /// divergence in \p Term. @@ -315,6 +318,7 @@ template class GenericSyncDependenceAnalysis { ModifiedPO CyclePO; const DominatorTreeT &DT; + const PostDominatorTreeT &PDT; const CycleInfoT &CI; DenseMap> @@ -336,6 +340,7 @@ template class GenericUniformityAnalysisImpl { using UseT = typename ContextT::UseT; using InstructionT = typename ContextT::InstructionT; using DominatorTreeT = typename ContextT::DominatorTreeT; + using PostDominatorTreeT = typename ContextT::PostDominatorTreeT; using CycleInfoT = GenericCycleInfo; using CycleT = typename CycleInfoT::CycleT; @@ -348,10 +353,12 @@ template class GenericUniformityAnalysisImpl { using TemporalDivergenceTuple = std::tuple; - GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI, + GenericUniformityAnalysisImpl(const DominatorTreeT &DT, + const PostDominatorTreeT &PDT, + const CycleInfoT &CI, const TargetTransformInfo *TTI) : Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI), - TTI(TTI), DT(DT), SDA(Context, DT, CI) {} + TTI(TTI), DT(DT), PDT(PDT), SDA(Context, DT, PDT, CI) {} void initialize(); @@ -435,6 +442,7 @@ template class GenericUniformityAnalysisImpl { private: const DominatorTreeT &DT; + const PostDominatorTreeT &PDT; // Recognized cycles with divergent exits. SmallPtrSet DivergentExitCycles; @@ -493,6 +501,7 @@ template class DivergencePropagator { public: using BlockT = typename ContextT::BlockT; using DominatorTreeT = typename ContextT::DominatorTreeT; + using PostDominatorTreeT = typename ContextT::PostDominatorTreeT; using FunctionT = typename ContextT::FunctionT; using ValueRefT = typename ContextT::ValueRefT; @@ -507,6 +516,7 @@ template class DivergencePropagator { const ModifiedPO &CyclePOT; const DominatorTreeT &DT; + const PostDominatorTreeT &PDT; const CycleInfoT &CI; const BlockT &DivTermBlock; const ContextT &Context; @@ -522,10 +532,11 @@ template class DivergencePropagator { BlockLabelMapT &BlockLabels; DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT, - const CycleInfoT &CI, const BlockT &DivTermBlock) - : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock), - Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT), - BlockLabels(DivDesc->BlockLabels) {} + const PostDominatorTreeT &PDT, const CycleInfoT &CI, + const BlockT &DivTermBlock) + : CyclePOT(CyclePOT), DT(DT), PDT(PDT), CI(CI), + DivTermBlock(DivTermBlock), Context(CI.getSSAContext()), + DivDesc(new DivergenceDescriptorT), BlockLabels(DivDesc->BlockLabels) {} void printDefs(raw_ostream &Out) { Out << "Propagator::BlockLabels {\n"; @@ -542,6 +553,12 @@ template class DivergencePropagator { Out << "}\n"; } + const BlockT *getIPDom(const BlockT *B) { + const auto *Node = PDT.getNode(B); + const auto *IPDomNode = Node->getIDom(); + return IPDomNode->getBlock(); + } + // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this // causes a divergent join. bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) { @@ -610,10 +627,11 @@ template class DivergencePropagator { LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << Context.print(&DivTermBlock) << "\n"); - // Early stopping criterion - int FloorIdx = CyclePOT.size() - 1; - const BlockT *FloorLabel = nullptr; - int DivTermIdx = CyclePOT.getIndex(&DivTermBlock); + // Immediate Post-dominator of DivTermBlock is the last join + // to visit. + const auto *ImmPDom = getIPDom(&DivTermBlock); + + LLVM_DEBUG(dbgs() << "Last join: " << Context.print(ImmPDom) << "\n"); // Bootstrap with branch targets auto const *DivTermCycle = CI.getCycle(&DivTermBlock); @@ -626,34 +644,29 @@ template class DivergencePropagator { LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: " << Context.print(SuccBlock) << "\n"); } - auto SuccIdx = CyclePOT.getIndex(SuccBlock); visitEdge(*SuccBlock, *SuccBlock); - FloorIdx = std::min(FloorIdx, SuccIdx); } while (true) { auto BlockIdx = FreshLabels.find_last(); - if (BlockIdx == -1 || BlockIdx < FloorIdx) + if (BlockIdx == -1) break; LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs())); FreshLabels.reset(BlockIdx); - if (BlockIdx == DivTermIdx) { - LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n"); + const auto *Block = CyclePOT[BlockIdx]; + if (Block == ImmPDom) { + LLVM_DEBUG(dbgs() << "Skipping ImmPDom\n"); continue; } - const auto *Block = CyclePOT[BlockIdx]; LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index " << BlockIdx << "\n"); const auto *Label = BlockLabels[Block]; assert(Label); - bool CausedJoin = false; - int LoweredFloorIdx = FloorIdx; - // If the current block is the header of a reducible cycle that // contains the divergent branch, then the label should be // propagated to the cycle exits. Such a header is the "last @@ -681,28 +694,11 @@ template class DivergencePropagator { if (const auto *BlockCycle = getReducibleParent(Block)) { SmallVector BlockCycleExits; BlockCycle->getExitBlocks(BlockCycleExits); - for (auto *BlockCycleExit : BlockCycleExits) { - CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label); - LoweredFloorIdx = - std::min(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit)); - } + for (auto *BlockCycleExit : BlockCycleExits) + visitCycleExitEdge(*BlockCycleExit, *Label); } else { - for (const auto *SuccBlock : successors(Block)) { - CausedJoin |= visitEdge(*SuccBlock, *Label); - LoweredFloorIdx = - std::min(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock)); - } - } - - // Floor update - if (CausedJoin) { - // 1. Different labels pushed to successors - FloorIdx = LoweredFloorIdx; - } else if (FloorLabel != Label) { - // 2. No join caused BUT we pushed a label that is different than the - // last pushed label - FloorIdx = LoweredFloorIdx; - FloorLabel = Label; + for (const auto *SuccBlock : successors(Block)) + visitEdge(*SuccBlock, *Label); } } @@ -742,8 +738,9 @@ typename llvm::GenericSyncDependenceAnalysis::DivergenceDescriptor template llvm::GenericSyncDependenceAnalysis::GenericSyncDependenceAnalysis( - const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI) - : CyclePO(Context), DT(DT), CI(CI) { + const ContextT &Context, const DominatorTreeT &DT, + const PostDominatorTreeT &PDT, const CycleInfoT &CI) + : CyclePO(Context), DT(DT), PDT(PDT), CI(CI) { CyclePO.compute(CI); } @@ -761,7 +758,7 @@ auto llvm::GenericSyncDependenceAnalysis::getJoinBlocks( return *ItCached->second; // compute all join points - DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock); + DivergencePropagatorT Propagator(CyclePO, DT, PDT, CI, *DivTermBlock); auto DivDesc = Propagator.computeJoinPoints(); auto printBlockSet = [&](ConstBlockSet &Blocks) { @@ -1155,9 +1152,9 @@ bool GenericUniformityAnalysisImpl::isAlwaysUniform( template GenericUniformityInfo::GenericUniformityInfo( - const DominatorTreeT &DT, const CycleInfoT &CI, - const TargetTransformInfo *TTI) { - DA.reset(new ImplT{DT, CI, TTI}); + const DominatorTreeT &DT, const PostDominatorTreeT &PDT, + const CycleInfoT &CI, const TargetTransformInfo *TTI) { + DA.reset(new ImplT{DT, PDT, CI, TTI}); } template diff --git a/llvm/include/llvm/ADT/GenericUniformityInfo.h b/llvm/include/llvm/ADT/GenericUniformityInfo.h index 9376fa6ee0bae..62d35582823dc 100644 --- a/llvm/include/llvm/ADT/GenericUniformityInfo.h +++ b/llvm/include/llvm/ADT/GenericUniformityInfo.h @@ -35,6 +35,7 @@ template class GenericUniformityInfo { using UseT = typename ContextT::UseT; using InstructionT = typename ContextT::InstructionT; using DominatorTreeT = typename ContextT::DominatorTreeT; + using PostDominatorTreeT = typename ContextT::PostDominatorTreeT; using ThisT = GenericUniformityInfo; using CycleInfoT = GenericCycleInfo; @@ -43,7 +44,8 @@ template class GenericUniformityInfo { using TemporalDivergenceTuple = std::tuple; - GenericUniformityInfo(const DominatorTreeT &DT, const CycleInfoT &CI, + GenericUniformityInfo(const DominatorTreeT &DT, const PostDominatorTreeT &PDT, + const CycleInfoT &CI, const TargetTransformInfo *TTI = nullptr); GenericUniformityInfo() = default; GenericUniformityInfo(GenericUniformityInfo &&) = default; diff --git a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h index e8c0dc9b43823..03fc9ebfcf442 100644 --- a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h +++ b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h @@ -18,6 +18,7 @@ #include "llvm/CodeGen/MachineCycleAnalysis.h" #include "llvm/CodeGen/MachineDominators.h" #include "llvm/CodeGen/MachinePassManager.h" +#include "llvm/CodeGen/MachinePostDominators.h" #include "llvm/CodeGen/MachineSSAContext.h" namespace llvm { @@ -31,7 +32,8 @@ using MachineUniformityInfo = GenericUniformityInfo; /// everything is uniform. MachineUniformityInfo computeMachineUniformityInfo( MachineFunction &F, const MachineCycleInfo &cycleInfo, - const MachineDominatorTree &domTree, bool HasBranchDivergence); + const MachineDominatorTree &domTree, + const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence); /// Legacy analysis pass which computes a \ref MachineUniformityInfo. class MachineUniformityAnalysisPass : public MachineFunctionPass { diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp index 2101fdfacfc8f..a724a8c26d7db 100644 --- a/llvm/lib/Analysis/UniformityAnalysis.cpp +++ b/llvm/lib/Analysis/UniformityAnalysis.cpp @@ -9,6 +9,7 @@ #include "llvm/Analysis/UniformityAnalysis.h" #include "llvm/ADT/GenericUniformityImpl.h" #include "llvm/Analysis/CycleAnalysis.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" @@ -114,9 +115,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter< llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F, FunctionAnalysisManager &FAM) { auto &DT = FAM.getResult(F); + auto &PDT = FAM.getResult(F); auto &TTI = FAM.getResult(F); auto &CI = FAM.getResult(F); - UniformityInfo UI{DT, CI, &TTI}; + UniformityInfo UI{DT, PDT, CI, &TTI}; // Skip computation if we can assume everything is uniform. if (TTI.hasBranchDivergence(&F)) UI.compute(); @@ -148,6 +150,7 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {} INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity", "Uniformity Analysis", true, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity", @@ -156,6 +159,7 @@ INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity", void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired(); + AU.addRequired(); AU.addRequiredTransitive(); AU.addRequired(); } @@ -163,11 +167,13 @@ void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { bool UniformityInfoWrapperPass::runOnFunction(Function &F) { auto &cycleInfo = getAnalysis().getResult(); auto &domTree = getAnalysis().getDomTree(); + auto &pdomTree = getAnalysis().getPostDomTree(); auto &targetTransformInfo = getAnalysis().getTTI(F); m_function = &F; - m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo}; + m_uniformityInfo = + UniformityInfo{domTree, pdomTree, cycleInfo, &targetTransformInfo}; // Skip computation if we can assume everything is uniform. if (targetTransformInfo.hasBranchDivergence(m_function)) diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp index e4b82ce83fda6..b87f8357ecfa8 100644 --- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp +++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp @@ -11,6 +11,7 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/CodeGen/MachineCycleAnalysis.h" #include "llvm/CodeGen/MachineDominators.h" +#include "llvm/CodeGen/MachinePostDominators.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/MachineSSAContext.h" #include "llvm/CodeGen/TargetInstrInfo.h" @@ -156,9 +157,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter< MachineUniformityInfo llvm::computeMachineUniformityInfo( MachineFunction &F, const MachineCycleInfo &cycleInfo, - const MachineDominatorTree &domTree, bool HasBranchDivergence) { + const MachineDominatorTree &domTree, + const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence) { assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); - MachineUniformityInfo UI(domTree, cycleInfo); + MachineUniformityInfo UI(domTree, pdomTree, cycleInfo); if (HasBranchDivergence) UI.compute(); return UI; @@ -184,12 +186,13 @@ MachineUniformityAnalysis::Result MachineUniformityAnalysis::run(MachineFunction &MF, MachineFunctionAnalysisManager &MFAM) { auto &DomTree = MFAM.getResult(MF); + auto &PDomTree = MFAM.getResult(MF); auto &CI = MFAM.getResult(MF); auto &FAM = MFAM.getResult(MF) .getManager(); auto &F = MF.getFunction(); auto &TTI = FAM.getResult(F); - return computeMachineUniformityInfo(MF, CI, DomTree, + return computeMachineUniformityInfo(MF, CI, DomTree, PDomTree, TTI.hasBranchDivergence(&F)); } @@ -215,6 +218,7 @@ INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity", "Machine Uniformity Info Analysis", false, true) INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass) INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity", "Machine Uniformity Info Analysis", false, true) @@ -222,15 +226,18 @@ void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequiredTransitive(); AU.addRequired(); + AU.addRequired(); MachineFunctionPass::getAnalysisUsage(AU); } bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { auto &DomTree = getAnalysis().getDomTree(); + auto &PDomTree = + getAnalysis().getPostDomTree(); auto &CI = getAnalysis().getCycleInfo(); // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a // default NoTTI - UI = computeMachineUniformityInfo(MF, CI, DomTree, true); + UI = computeMachineUniformityInfo(MF, CI, DomTree, PDomTree, true); return false; } diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll new file mode 100644 index 0000000000000..df949a86635c4 --- /dev/null +++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll @@ -0,0 +1,78 @@ +; +; RUN: opt -mtriple amdgcn-- -passes='print' -disable-output %s 2>&1 | FileCheck %s +; +; This is to test an if-then-else case with some unmerged basic blocks +; (https://github.com/llvm/llvm-project/issues/137277) +; +; Entry (div.cond) +; / \ +; B0 B3 +; | | +; B1 B4 +; | | +; B2 B5 +; \ / +; B6 (phi: divergent) +; + + +; CHECK-LABEL: 'test_ctrl_divergence': +; CHECK-LABEL: BLOCK Entry +; CHECK: DIVERGENT: %div.cond = icmp eq i32 %tid, 0 +; CHECK: DIVERGENT: br i1 %div.cond, label %B3, label %B0 +; +; CHECK-LABEL: BLOCK B6 +; CHECK: DIVERGENT: %div_a = phi i32 [ %a0, %B2 ], [ %a1, %B5 ] +; CHECK: DIVERGENT: %div_b = phi i32 [ %b0, %B2 ], [ %b1, %B5 ] +; CHECK: DIVERGENT: %div_c = phi i32 [ %c0, %B2 ], [ %c1, %B5 ] + + +define amdgpu_kernel void @test_ctrl_divergence(i32 %a, i32 %b, i32 %c, i32 %d) { +Entry: + %tid = call i32 @llvm.amdgcn.workitem.id.x() + %div.cond = icmp eq i32 %tid, 0 + br i1 %div.cond, label %B3, label %B0 ; divergent branch + +B0: + %a0 = add i32 %a, 1 + br label %B1 + +B1: + %b0 = add i32 %b, 2 + br label %B2 + +B2: + %c0 = add i32 %c, 3 + br label %B6 + +B3: + %a1 = add i32 %a, 10 + br label %B4 + +B4: + %b1 = add i32 %b, 20 + br label %B5 + +B5: + %c1 = add i32 %c, 30 + br label %B6 + +B6: + %div_a = phi i32 [%a0, %B2], [%a1, %B5] + %div_b = phi i32 [%b0, %B2], [%b1, %B5] + %div_c = phi i32 [%c0, %B2], [%c1, %B5] + br i1 %div.cond, label %B8, label %B7 ; divergent branch + +B7: + %d1 = add i32 %d, 1 + br label %B8 + +B8: + %div_d = phi i32 [%d1, %B7], [%d, %B6] + ret void +} + + +declare i32 @llvm.amdgcn.workitem.id.x() #0 + +attributes #0 = {nounwind readnone } diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_loop.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_loop.ll new file mode 100644 index 0000000000000..54c641862fe79 --- /dev/null +++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_loop.ll @@ -0,0 +1,82 @@ +; +; RUN: opt -mtriple amdgcn-- -passes='print' -disable-output %s 2>&1 | FileCheck %s +; +; This is to test a divergent phi involving loops +; (https://github.com/llvm/llvm-project/issues/137277). +; +; B0 (div.cond) +; / \ +; (L)B1 B4 +; | | +; B2 B5 (L) +; | | +; B3 / +; \ / +; B6 (phi: divergent) +; + +; +; CHECK-LABEL: UniformityInfo for function 'test_loop_ctrl_divergence': +; CHECK-LABEL: BLOCK Entry +; CHECK: DIVERGENT: %tid = call i32 @llvm.amdgcn.workitem.id.x() +; CHECK-LABEL: BLOCK B0 +; CHECK: DIVERGENT: %div.cond = icmp eq i32 %tid, 0 +; CHECK-LABEL: BLOCK B3 +; CHECK: %uni_a = phi i32 [ %a1, %B2 ], [ %a, %Entry ] +; CHECK-LABEL: BLOCK B5 +; CHECK: %uni.a3 = phi i32 [ %a2, %B4 ], [ %uni_a3, %B5 ] +; CHECK-LABEL BLOCK B6 +; CHECK: DIVERGENT: %div_a = phi i32 [ %uni_a, %B3 ], [ %uni_a3, %B5 ] +; + +define amdgpu_kernel void @test_loop_ctrl_divergence(i32 %a, i32 %b, i32 %c, i32 %d) { +Entry: + %tid = call i32 @llvm.amdgcn.workitem.id.x() + %uni.cond0 = icmp eq i32 %d, 0 + br i1 %uni.cond0, label %B3, label %B0 ; uniform branch + +B0: + %div.cond = icmp eq i32 %tid, 0 + br i1 %div.cond, label %B4, label %B1 ; divergent branch + +B1: + %uni.a0 = phi i32 [%a, %B0], [%a0, %B1] + %a0 = add i32 %uni.a0, 1 + %uni.cond1 = icmp slt i32 %a0, %b + br i1 %uni.cond1, label %B1, label %B2 + +B2: + %a1 = add i32 %a0, 10 + br label %B3 + +B3: + %uni_a = phi i32 [%a1, %B2], [%a, %Entry] + br label %B6 + +B4: + %a2 = add i32 %a, 20 + br label %B5 + +B5: + %uni.a3= phi i32 [%a2, %B4], [%uni_a3, %B5] + %uni_a3 = add i32 %uni.a3, 1 + %uni.cond2 = icmp slt i32 %uni_a3, %c + br i1 %uni.cond2, label %B5, label %B6 + +B6: + %div_a = phi i32 [%uni_a, %B3], [%uni_a3, %B5] ; divergent + %div.cond2 = icmp eq i32 %tid, 2 + br i1 %div.cond2, label %B7, label %B8 ; divergent branch + +B7: + %c0 = add i32 %div_a, 2 ; divergent + br label %B8 + +B8: + %ret = phi i32 [%c0, %B7], [0, %B6] ; divergent + ret void +} + +declare i32 @llvm.amdgcn.workitem.id.x() #0 + +attributes #0 = {nounwind readnone }