Skip to content

Commit 6c2781e

Browse files
authored
[GVN] Share equality propagation for assume and condition (#161639)
GVN currently has two different implementation of equality propagation. One of them is used for branch conditions (dominating an edge), which performs replacements across multiple blocks. This is also used for assumes to handle uses outside the current block. However, uses inside the block are handled using a completely separate implementation, which involves populating a replacement map and then checking it for individual instructions during normal GVN. While this approach generally makes sense, it is kind of pointless if we already do a use walk to handle the cross-block case anyway. This PR generalizes propagateEquality() to accept either a BasicBlockEdge or an Instruction* and replace dominated users. This removes the need for special handling of uses in the same block for assumes, as they're covered by instruction dominance, and ensures that both implementations do not go out of sync.
1 parent c32753a commit 6c2781e

File tree

5 files changed

+98
-155
lines changed

5 files changed

+98
-155
lines changed

llvm/include/llvm/Transforms/Scalar/GVN.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <cstdint>
2929
#include <optional>
3030
#include <utility>
31+
#include <variant>
3132
#include <vector>
3233

3334
namespace llvm {
@@ -322,11 +323,6 @@ class GVNPass : public PassInfoMixin<GVNPass> {
322323
};
323324
LeaderMap LeaderTable;
324325

325-
// Block-local map of equivalent values to their leader, does not
326-
// propagate to any successors. Entries added mid-block are applied
327-
// to the remaining instructions in the block.
328-
SmallMapVector<Value *, Value *, 4> ReplaceOperandsWithMap;
329-
330326
// Map the block to reversed postorder traversal number. It is used to
331327
// find back edge easily.
332328
DenseMap<AssertingVH<BasicBlock>, uint32_t> BlockRPONumber;
@@ -402,9 +398,9 @@ class GVNPass : public PassInfoMixin<GVNPass> {
402398
void verifyRemoved(const Instruction *I) const;
403399
bool splitCriticalEdges();
404400
BasicBlock *splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ);
405-
bool replaceOperandsForInBlockEquality(Instruction *I) const;
406-
bool propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root,
407-
bool DominatesByEdge);
401+
bool
402+
propagateEquality(Value *LHS, Value *RHS,
403+
const std::variant<BasicBlockEdge, Instruction *> &Root);
408404
bool processFoldableCondBr(BranchInst *BI);
409405
void addDeadBlock(BasicBlock *BB);
410406
void assignValNumForDeadCode();

llvm/include/llvm/Transforms/Utils/Local.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,11 @@ LLVM_ABI unsigned replaceDominatedUsesWith(Value *From, Value *To,
452452
LLVM_ABI unsigned replaceDominatedUsesWith(Value *From, Value *To,
453453
DominatorTree &DT,
454454
const BasicBlock *BB);
455+
/// Replace each use of 'From' with 'To' if that use is dominated by the
456+
/// given instruction. Returns the number of replacements made.
457+
LLVM_ABI unsigned replaceDominatedUsesWith(Value *From, Value *To,
458+
DominatorTree &DT,
459+
const Instruction *I);
455460
/// Replace each use of 'From' with 'To' if that use is dominated by
456461
/// the given edge and the callback ShouldReplace returns true. Returns the
457462
/// number of replacements made.
@@ -464,6 +469,12 @@ LLVM_ABI unsigned replaceDominatedUsesWithIf(
464469
LLVM_ABI unsigned replaceDominatedUsesWithIf(
465470
Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB,
466471
function_ref<bool(const Use &U, const Value *To)> ShouldReplace);
472+
/// Replace each use of 'From' with 'To' if that use is dominated by
473+
/// the given instruction and the callback ShouldReplace returns true. Returns
474+
/// the number of replacements made.
475+
LLVM_ABI unsigned replaceDominatedUsesWithIf(
476+
Value *From, Value *To, DominatorTree &DT, const Instruction *I,
477+
function_ref<bool(const Use &U, const Value *To)> ShouldReplace);
467478

468479
/// Return true if this call calls a gc leaf function.
469480
///

llvm/lib/Transforms/Scalar/GVN.cpp

Lines changed: 66 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,13 +2084,6 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) {
20842084
return Changed;
20852085
}
20862086

2087-
static bool hasUsersIn(Value *V, BasicBlock *BB) {
2088-
return any_of(V->users(), [BB](User *U) {
2089-
auto *I = dyn_cast<Instruction>(U);
2090-
return I && I->getParent() == BB;
2091-
});
2092-
}
2093-
20942087
bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
20952088
Value *V = IntrinsicI->getArgOperand(0);
20962089

@@ -2149,85 +2142,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
21492142
}
21502143

21512144
Constant *True = ConstantInt::getTrue(V->getContext());
2152-
bool Changed = false;
2153-
2154-
for (BasicBlock *Successor : successors(IntrinsicI->getParent())) {
2155-
BasicBlockEdge Edge(IntrinsicI->getParent(), Successor);
2156-
2157-
// This property is only true in dominated successors, propagateEquality
2158-
// will check dominance for us.
2159-
Changed |= propagateEquality(V, True, Edge, false);
2160-
}
2161-
2162-
// We can replace assume value with true, which covers cases like this:
2163-
// call void @llvm.assume(i1 %cmp)
2164-
// br i1 %cmp, label %bb1, label %bb2 ; will change %cmp to true
2165-
ReplaceOperandsWithMap[V] = True;
2166-
2167-
// Similarly, after assume(!NotV) we know that NotV == false.
2168-
Value *NotV;
2169-
if (match(V, m_Not(m_Value(NotV))))
2170-
ReplaceOperandsWithMap[NotV] = ConstantInt::getFalse(V->getContext());
2171-
2172-
// If we find an equality fact, canonicalize all dominated uses in this block
2173-
// to one of the two values. We heuristically choice the "oldest" of the
2174-
// two where age is determined by value number. (Note that propagateEquality
2175-
// above handles the cross block case.)
2176-
//
2177-
// Key case to cover are:
2178-
// 1)
2179-
// %cmp = fcmp oeq float 3.000000e+00, %0 ; const on lhs could happen
2180-
// call void @llvm.assume(i1 %cmp)
2181-
// ret float %0 ; will change it to ret float 3.000000e+00
2182-
// 2)
2183-
// %load = load float, float* %addr
2184-
// %cmp = fcmp oeq float %load, %0
2185-
// call void @llvm.assume(i1 %cmp)
2186-
// ret float %load ; will change it to ret float %0
2187-
if (auto *CmpI = dyn_cast<CmpInst>(V)) {
2188-
if (CmpI->isEquivalence()) {
2189-
Value *CmpLHS = CmpI->getOperand(0);
2190-
Value *CmpRHS = CmpI->getOperand(1);
2191-
// Heuristically pick the better replacement -- the choice of heuristic
2192-
// isn't terribly important here, but the fact we canonicalize on some
2193-
// replacement is for exposing other simplifications.
2194-
// TODO: pull this out as a helper function and reuse w/ existing
2195-
// (slightly different) logic.
2196-
if (isa<Constant>(CmpLHS) && !isa<Constant>(CmpRHS))
2197-
std::swap(CmpLHS, CmpRHS);
2198-
if (!isa<Instruction>(CmpLHS) && isa<Instruction>(CmpRHS))
2199-
std::swap(CmpLHS, CmpRHS);
2200-
if ((isa<Argument>(CmpLHS) && isa<Argument>(CmpRHS)) ||
2201-
(isa<Instruction>(CmpLHS) && isa<Instruction>(CmpRHS))) {
2202-
// Move the 'oldest' value to the right-hand side, using the value
2203-
// number as a proxy for age.
2204-
uint32_t LVN = VN.lookupOrAdd(CmpLHS);
2205-
uint32_t RVN = VN.lookupOrAdd(CmpRHS);
2206-
if (LVN < RVN)
2207-
std::swap(CmpLHS, CmpRHS);
2208-
}
2209-
2210-
// Handle degenerate case where we either haven't pruned a dead path or a
2211-
// removed a trivial assume yet.
2212-
if (isa<Constant>(CmpLHS) && isa<Constant>(CmpRHS))
2213-
return Changed;
2214-
2215-
LLVM_DEBUG(dbgs() << "Replacing dominated uses of "
2216-
<< *CmpLHS << " with "
2217-
<< *CmpRHS << " in block "
2218-
<< IntrinsicI->getParent()->getName() << "\n");
2219-
2220-
// Setup the replacement map - this handles uses within the same block.
2221-
if (hasUsersIn(CmpLHS, IntrinsicI->getParent()))
2222-
ReplaceOperandsWithMap[CmpLHS] = CmpRHS;
2223-
2224-
// NOTE: The non-block local cases are handled by the call to
2225-
// propagateEquality above; this block is just about handling the block
2226-
// local cases. TODO: There's a bunch of logic in propagateEqualiy which
2227-
// isn't duplicated for the block local case, can we share it somehow?
2228-
}
2229-
}
2230-
return Changed;
2145+
return propagateEquality(V, True, IntrinsicI);
22312146
}
22322147

22332148
static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) {
@@ -2526,39 +2441,28 @@ void GVNPass::assignBlockRPONumber(Function &F) {
25262441
InvalidBlockRPONumbers = false;
25272442
}
25282443

2529-
bool GVNPass::replaceOperandsForInBlockEquality(Instruction *Instr) const {
2530-
bool Changed = false;
2531-
for (unsigned OpNum = 0; OpNum < Instr->getNumOperands(); ++OpNum) {
2532-
Use &Operand = Instr->getOperandUse(OpNum);
2533-
auto It = ReplaceOperandsWithMap.find(Operand.get());
2534-
if (It != ReplaceOperandsWithMap.end()) {
2535-
const DataLayout &DL = Instr->getDataLayout();
2536-
if (!canReplacePointersInUseIfEqual(Operand, It->second, DL))
2537-
continue;
2538-
2539-
LLVM_DEBUG(dbgs() << "GVN replacing: " << *Operand << " with "
2540-
<< *It->second << " in instruction " << *Instr << '\n');
2541-
Instr->setOperand(OpNum, It->second);
2542-
Changed = true;
2543-
}
2544-
}
2545-
return Changed;
2546-
}
2547-
2548-
/// The given values are known to be equal in every block
2444+
/// The given values are known to be equal in every use
25492445
/// dominated by 'Root'. Exploit this, for example by replacing 'LHS' with
25502446
/// 'RHS' everywhere in the scope. Returns whether a change was made.
2551-
/// If DominatesByEdge is false, then it means that we will propagate the RHS
2552-
/// value starting from the end of Root.Start.
2553-
bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
2554-
const BasicBlockEdge &Root,
2555-
bool DominatesByEdge) {
2447+
/// The Root may either be a basic block edge (for conditions) or an
2448+
/// instruction (for assumes).
2449+
bool GVNPass::propagateEquality(
2450+
Value *LHS, Value *RHS,
2451+
const std::variant<BasicBlockEdge, Instruction *> &Root) {
25562452
SmallVector<std::pair<Value*, Value*>, 4> Worklist;
25572453
Worklist.push_back(std::make_pair(LHS, RHS));
25582454
bool Changed = false;
2559-
// For speed, compute a conservative fast approximation to
2560-
// DT->dominates(Root, Root.getEnd());
2561-
const bool RootDominatesEnd = isOnlyReachableViaThisEdge(Root, DT);
2455+
SmallVector<const BasicBlock *> DominatedBlocks;
2456+
if (const BasicBlockEdge *Edge = std::get_if<BasicBlockEdge>(&Root)) {
2457+
// For speed, compute a conservative fast approximation to
2458+
// DT->dominates(Root, Root.getEnd());
2459+
if (isOnlyReachableViaThisEdge(*Edge, DT))
2460+
DominatedBlocks.push_back(Edge->getEnd());
2461+
} else {
2462+
Instruction *I = std::get<Instruction *>(Root);
2463+
for (const auto *Node : DT->getNode(I->getParent())->children())
2464+
DominatedBlocks.push_back(Node->getBlock());
2465+
}
25622466

25632467
while (!Worklist.empty()) {
25642468
std::pair<Value*, Value*> Item = Worklist.pop_back_val();
@@ -2606,9 +2510,9 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
26062510
// using the leader table is about compiling faster, not optimizing better).
26072511
// The leader table only tracks basic blocks, not edges. Only add to if we
26082512
// have the simple case where the edge dominates the end.
2609-
if (RootDominatesEnd && !isa<Instruction>(RHS) &&
2610-
canReplacePointersIfEqual(LHS, RHS, DL))
2611-
LeaderTable.insert(LVN, RHS, Root.getEnd());
2513+
if (!isa<Instruction>(RHS) && canReplacePointersIfEqual(LHS, RHS, DL))
2514+
for (const BasicBlock *BB : DominatedBlocks)
2515+
LeaderTable.insert(LVN, RHS, BB);
26122516

26132517
// Replace all occurrences of 'LHS' with 'RHS' everywhere in the scope. As
26142518
// LHS always has at least one use that is not dominated by Root, this will
@@ -2618,12 +2522,14 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
26182522
auto CanReplacePointersCallBack = [&DL](const Use &U, const Value *To) {
26192523
return canReplacePointersInUseIfEqual(U, To, DL);
26202524
};
2621-
unsigned NumReplacements =
2622-
DominatesByEdge
2623-
? replaceDominatedUsesWithIf(LHS, RHS, *DT, Root,
2624-
CanReplacePointersCallBack)
2625-
: replaceDominatedUsesWithIf(LHS, RHS, *DT, Root.getStart(),
2626-
CanReplacePointersCallBack);
2525+
unsigned NumReplacements;
2526+
if (const BasicBlockEdge *Edge = std::get_if<BasicBlockEdge>(&Root))
2527+
NumReplacements = replaceDominatedUsesWithIf(
2528+
LHS, RHS, *DT, *Edge, CanReplacePointersCallBack);
2529+
else
2530+
NumReplacements = replaceDominatedUsesWithIf(
2531+
LHS, RHS, *DT, std::get<Instruction *>(Root),
2532+
CanReplacePointersCallBack);
26272533

26282534
if (NumReplacements > 0) {
26292535
Changed = true;
@@ -2682,26 +2588,45 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
26822588
// If the number we were assigned was brand new then there is no point in
26832589
// looking for an instruction realizing it: there cannot be one!
26842590
if (Num < NextNum) {
2685-
Value *NotCmp = findLeader(Root.getEnd(), Num);
2686-
if (NotCmp && isa<Instruction>(NotCmp)) {
2687-
unsigned NumReplacements =
2688-
DominatesByEdge
2689-
? replaceDominatedUsesWith(NotCmp, NotVal, *DT, Root)
2690-
: replaceDominatedUsesWith(NotCmp, NotVal, *DT,
2691-
Root.getStart());
2692-
Changed |= NumReplacements > 0;
2693-
NumGVNEqProp += NumReplacements;
2694-
// Cached information for anything that uses NotCmp will be invalid.
2695-
if (MD)
2696-
MD->invalidateCachedPointerInfo(NotCmp);
2591+
for (const auto &Entry : LeaderTable.getLeaders(Num)) {
2592+
// Only look at leaders that either dominate the start of the edge,
2593+
// or are dominated by the end. This check is not necessary for
2594+
// correctness, it only discards cases for which the following
2595+
// use replacement will not work anyway.
2596+
if (const BasicBlockEdge *Edge = std::get_if<BasicBlockEdge>(&Root)) {
2597+
if (!DT->dominates(Entry.BB, Edge->getStart()) &&
2598+
!DT->dominates(Edge->getEnd(), Entry.BB))
2599+
continue;
2600+
} else {
2601+
auto *InstBB = std::get<Instruction *>(Root)->getParent();
2602+
if (!DT->dominates(Entry.BB, InstBB) &&
2603+
!DT->dominates(InstBB, Entry.BB))
2604+
continue;
2605+
}
2606+
2607+
Value *NotCmp = Entry.Val;
2608+
if (NotCmp && isa<Instruction>(NotCmp)) {
2609+
unsigned NumReplacements;
2610+
if (const BasicBlockEdge *Edge = std::get_if<BasicBlockEdge>(&Root))
2611+
NumReplacements =
2612+
replaceDominatedUsesWith(NotCmp, NotVal, *DT, *Edge);
2613+
else
2614+
NumReplacements = replaceDominatedUsesWith(
2615+
NotCmp, NotVal, *DT, std::get<Instruction *>(Root));
2616+
Changed |= NumReplacements > 0;
2617+
NumGVNEqProp += NumReplacements;
2618+
// Cached information for anything that uses NotCmp will be invalid.
2619+
if (MD)
2620+
MD->invalidateCachedPointerInfo(NotCmp);
2621+
}
26972622
}
26982623
}
26992624
// Ensure that any instruction in scope that gets the "A < B" value number
27002625
// is replaced with false.
27012626
// The leader table only tracks basic blocks, not edges. Only add to if we
27022627
// have the simple case where the edge dominates the end.
2703-
if (RootDominatesEnd)
2704-
LeaderTable.insert(Num, NotVal, Root.getEnd());
2628+
for (const BasicBlock *BB : DominatedBlocks)
2629+
LeaderTable.insert(Num, NotVal, BB);
27052630

27062631
continue;
27072632
}
@@ -2789,11 +2714,11 @@ bool GVNPass::processInstruction(Instruction *I) {
27892714

27902715
Value *TrueVal = ConstantInt::getTrue(TrueSucc->getContext());
27912716
BasicBlockEdge TrueE(Parent, TrueSucc);
2792-
Changed |= propagateEquality(BranchCond, TrueVal, TrueE, true);
2717+
Changed |= propagateEquality(BranchCond, TrueVal, TrueE);
27932718

27942719
Value *FalseVal = ConstantInt::getFalse(FalseSucc->getContext());
27952720
BasicBlockEdge FalseE(Parent, FalseSucc);
2796-
Changed |= propagateEquality(BranchCond, FalseVal, FalseE, true);
2721+
Changed |= propagateEquality(BranchCond, FalseVal, FalseE);
27972722

27982723
return Changed;
27992724
}
@@ -2814,7 +2739,7 @@ bool GVNPass::processInstruction(Instruction *I) {
28142739
// If there is only a single edge, propagate the case value into it.
28152740
if (SwitchEdges.lookup(Dst) == 1) {
28162741
BasicBlockEdge E(Parent, Dst);
2817-
Changed |= propagateEquality(SwitchCond, Case.getCaseValue(), E, true);
2742+
Changed |= propagateEquality(SwitchCond, Case.getCaseValue(), E);
28182743
}
28192744
}
28202745
return Changed;
@@ -2942,8 +2867,6 @@ bool GVNPass::processBlock(BasicBlock *BB) {
29422867
if (DeadBlocks.count(BB))
29432868
return false;
29442869

2945-
// Clearing map before every BB because it can be used only for single BB.
2946-
ReplaceOperandsWithMap.clear();
29472870
bool ChangedFunction = false;
29482871

29492872
// Since we may not have visited the input blocks of the phis, we can't
@@ -2955,11 +2878,8 @@ bool GVNPass::processBlock(BasicBlock *BB) {
29552878
for (PHINode *PN : PHINodesToRemove) {
29562879
removeInstruction(PN);
29572880
}
2958-
for (Instruction &Inst : make_early_inc_range(*BB)) {
2959-
if (!ReplaceOperandsWithMap.empty())
2960-
ChangedFunction |= replaceOperandsForInBlockEquality(&Inst);
2881+
for (Instruction &Inst : make_early_inc_range(*BB))
29612882
ChangedFunction |= processInstruction(&Inst);
2962-
}
29632883
return ChangedFunction;
29642884
}
29652885

llvm/lib/Transforms/Utils/Local.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3246,6 +3246,13 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
32463246
return ::replaceDominatedUsesWith(From, To, Dominates);
32473247
}
32483248

3249+
unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
3250+
DominatorTree &DT,
3251+
const Instruction *I) {
3252+
auto Dominates = [&](const Use &U) { return DT.dominates(I, U); };
3253+
return ::replaceDominatedUsesWith(From, To, Dominates);
3254+
}
3255+
32493256
unsigned llvm::replaceDominatedUsesWithIf(
32503257
Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Root,
32513258
function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
@@ -3264,6 +3271,15 @@ unsigned llvm::replaceDominatedUsesWithIf(
32643271
return ::replaceDominatedUsesWith(From, To, DominatesAndShouldReplace);
32653272
}
32663273

3274+
unsigned llvm::replaceDominatedUsesWithIf(
3275+
Value *From, Value *To, DominatorTree &DT, const Instruction *I,
3276+
function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
3277+
auto DominatesAndShouldReplace = [&](const Use &U) {
3278+
return DT.dominates(I, U) && ShouldReplace(U, To);
3279+
};
3280+
return ::replaceDominatedUsesWith(From, To, DominatesAndShouldReplace);
3281+
}
3282+
32673283
bool llvm::callsGCLeafFunction(const CallBase *Call,
32683284
const TargetLibraryInfo &TLI) {
32693285
// Check if the function is specifically marked as a gc leaf function.

llvm/test/Transforms/GVN/condprop.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ define i1 @test6_phi2(i1 %c, i32 %x, i32 %y) {
360360
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X]], [[Y]]
361361
; CHECK-NEXT: br i1 [[CMP]], label [[BB2]], label [[BB3:%.*]]
362362
; CHECK: bb2:
363-
; CHECK-NEXT: [[PHI:%.*]] = phi i1 [ [[CMP_NOT]], [[BB1]] ], [ true, [[ENTRY:%.*]] ]
363+
; CHECK-NEXT: [[PHI:%.*]] = phi i1 [ false, [[BB1]] ], [ true, [[ENTRY:%.*]] ]
364364
; CHECK-NEXT: ret i1 [[PHI]]
365365
; CHECK: bb3:
366366
; CHECK-NEXT: ret i1 false

0 commit comments

Comments
 (0)