Skip to content

Commit 03f1d09

Browse files
committed
[FuncSpec] Add Phi nodes to the InstCostVisitor.
This patch allows constant folding of PHIs when estimating the user bonus. Phi nodes are a special case since some of their inputs may remain unresolved until all the specialization arguments have been processed by the InstCostVisitor. Therefore, we keep a list of dead basic blocks and then lazily visit the Phi nodes once the user bonus has been computed for all the specialization arguments. Differential Revision: https://reviews.llvm.org/D154852
1 parent c1b8297 commit 03f1d09

File tree

3 files changed

+143
-7
lines changed

3 files changed

+143
-7
lines changed

llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
126126
SCCPSolver &Solver;
127127

128128
ConstMap KnownConstants;
129+
// Basic blocks known to be unreachable after constant propagation.
130+
DenseSet<BasicBlock *> DeadBlocks;
131+
// PHI nodes we have visited before.
132+
DenseSet<Instruction *> VisitedPHIs;
133+
// PHI nodes we have visited once without successfully constant folding them.
134+
// Once the InstCostVisitor has processed all the specialization arguments,
135+
// it should be possible to determine whether those PHIs can be folded
136+
// (some of their incoming values may have become constant or dead).
137+
SmallVector<Instruction *> PendingPHIs;
129138

130139
ConstMap::iterator LastVisited;
131140

@@ -134,7 +143,10 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
134143
TargetTransformInfo &TTI, SCCPSolver &Solver)
135144
: DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}
136145

137-
Cost getUserBonus(Instruction *User, Value *Use, Constant *C);
146+
Cost getUserBonus(Instruction *User, Value *Use = nullptr,
147+
Constant *C = nullptr);
148+
149+
Cost getBonusFromPendingPHIs();
138150

139151
private:
140152
friend class InstVisitor<InstCostVisitor, Constant *>;
@@ -143,6 +155,7 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
143155
Cost estimateBranchInst(BranchInst &I);
144156

145157
Constant *visitInstruction(Instruction &I) { return nullptr; }
158+
Constant *visitPHINode(PHINode &I);
146159
Constant *visitFreezeInst(FreezeInst &I);
147160
Constant *visitCallBase(CallBase &I);
148161
Constant *visitLoadInst(LoadInst &I);

llvm/lib/Transforms/IPO/FunctionSpecialization.cpp

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ static cl::opt<unsigned> MaxClones(
7878
"The maximum number of clones allowed for a single function "
7979
"specialization"));
8080

81+
static cl::opt<unsigned> MaxIncomingPhiValues(
82+
"funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc(
83+
"The maximum number of incoming values a PHI node can have to be "
84+
"considered during the specialization bonus estimation"));
85+
8186
static cl::opt<unsigned> MinFunctionSize(
8287
"funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc(
8388
"Don't specialize functions that have less than this number of "
@@ -104,6 +109,7 @@ static cl::opt<bool> SpecializeLiteralConstant(
104109
// the combination of size and latency savings in comparison to the non
105110
// specialized version of the function.
106111
static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
112+
DenseSet<BasicBlock *> &DeadBlocks,
107113
ConstMap &KnownConstants, SCCPSolver &Solver,
108114
BlockFrequencyInfo &BFI,
109115
TargetTransformInfo &TTI) {
@@ -118,6 +124,12 @@ static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
118124
if (!Weight)
119125
continue;
120126

127+
// These blocks are considered dead as far as the InstCostVisitor is
128+
// concerned. They haven't been proven dead yet by the Solver, but
129+
// may become if we propagate the constant specialization arguments.
130+
if (!DeadBlocks.insert(BB).second)
131+
continue;
132+
121133
for (Instruction &I : *BB) {
122134
// Disregard SSA copies.
123135
if (auto *II = dyn_cast<IntrinsicInst>(&I))
@@ -152,9 +164,19 @@ static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
152164
return nullptr;
153165
}
154166

167+
Cost InstCostVisitor::getBonusFromPendingPHIs() {
168+
Cost Bonus = 0;
169+
while (!PendingPHIs.empty()) {
170+
Instruction *Phi = PendingPHIs.pop_back_val();
171+
Bonus += getUserBonus(Phi);
172+
}
173+
return Bonus;
174+
}
175+
155176
Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
156177
// Cache the iterator before visiting.
157-
LastVisited = KnownConstants.insert({Use, C}).first;
178+
LastVisited = Use ? KnownConstants.insert({Use, C}).first
179+
: KnownConstants.end();
158180

159181
if (auto *I = dyn_cast<SwitchInst>(User))
160182
return estimateSwitchInst(*I);
@@ -181,13 +203,15 @@ Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
181203

182204
for (auto *U : User->users())
183205
if (auto *UI = dyn_cast<Instruction>(U))
184-
if (Solver.isBlockExecutable(UI->getParent()))
206+
if (UI != User && Solver.isBlockExecutable(UI->getParent()))
185207
Bonus += getUserBonus(UI, User, C);
186208

187209
return Bonus;
188210
}
189211

190212
Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
213+
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
214+
191215
if (I.getCondition() != LastVisited->first)
192216
return 0;
193217

@@ -208,10 +232,13 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
208232
WorkList.push_back(BB);
209233
}
210234

211-
return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
235+
return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
236+
TTI);
212237
}
213238

214239
Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
240+
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
241+
215242
if (I.getCondition() != LastVisited->first)
216243
return 0;
217244

@@ -223,10 +250,39 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
223250
Succ->getUniquePredecessor() == I.getParent())
224251
WorkList.push_back(Succ);
225252

226-
return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
253+
return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
254+
TTI);
255+
}
256+
257+
Constant *InstCostVisitor::visitPHINode(PHINode &I) {
258+
if (I.getNumIncomingValues() > MaxIncomingPhiValues)
259+
return nullptr;
260+
261+
bool Inserted = VisitedPHIs.insert(&I).second;
262+
Constant *Const = nullptr;
263+
264+
for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) {
265+
Value *V = I.getIncomingValue(Idx);
266+
if (auto *Inst = dyn_cast<Instruction>(V))
267+
if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx)))
268+
continue;
269+
Constant *C = findConstantFor(V, KnownConstants);
270+
if (!C) {
271+
if (Inserted)
272+
PendingPHIs.push_back(&I);
273+
return nullptr;
274+
}
275+
if (!Const)
276+
Const = C;
277+
else if (C != Const)
278+
return nullptr;
279+
}
280+
return Const;
227281
}
228282

229283
Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
284+
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
285+
230286
if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second))
231287
return LastVisited->second;
232288
return nullptr;
@@ -253,6 +309,8 @@ Constant *InstCostVisitor::visitCallBase(CallBase &I) {
253309
}
254310

255311
Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
312+
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
313+
256314
if (isa<ConstantPointerNull>(LastVisited->second))
257315
return nullptr;
258316
return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL);
@@ -275,6 +333,8 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
275333
}
276334

277335
Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
336+
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
337+
278338
if (I.getCondition() != LastVisited->first)
279339
return nullptr;
280340

@@ -290,6 +350,8 @@ Constant *InstCostVisitor::visitCastInst(CastInst &I) {
290350
}
291351

292352
Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
353+
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
354+
293355
bool Swap = I.getOperand(1) == LastVisited->first;
294356
Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
295357
Constant *Other = findConstantFor(V, KnownConstants);
@@ -303,10 +365,14 @@ Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
303365
}
304366

305367
Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
368+
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
369+
306370
return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL);
307371
}
308372

309373
Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
374+
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
375+
310376
bool Swap = I.getOperand(1) == LastVisited->first;
311377
Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
312378
Constant *Other = findConstantFor(V, KnownConstants);
@@ -713,13 +779,17 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
713779
AllSpecs[Index].CallSites.push_back(&CS);
714780
} else {
715781
// Calculate the specialisation gain.
716-
Cost Score = 0 - SpecCost;
782+
Cost Score = 0;
717783
InstCostVisitor Visitor = getInstCostVisitorFor(F);
718784
for (ArgInfo &A : S.Args)
719785
Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
786+
Score += Visitor.getBonusFromPendingPHIs();
787+
788+
LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = "
789+
<< Score << "\n");
720790

721791
// Discard unprofitable specialisations.
722-
if (!ForceSpecialization && Score <= 0)
792+
if (!ForceSpecialization && Score <= SpecCost)
723793
continue;
724794

725795
// Create a new specialisation entry.

llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,56 @@ TEST_F(FunctionSpecializationTest, Misc) {
287287
Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
288288
EXPECT_TRUE(Bonus == 0);
289289
}
290+
291+
TEST_F(FunctionSpecializationTest, PhiNode) {
292+
const char *ModuleString = R"(
293+
define void @foo(i32 %a, i32 %b, i32 %i) {
294+
entry:
295+
br label %loop
296+
loop:
297+
switch i32 %i, label %default
298+
[ i32 1, label %case1
299+
i32 2, label %case2 ]
300+
case1:
301+
%0 = add i32 %a, 1
302+
br label %bb
303+
case2:
304+
%1 = sub i32 %b, 1
305+
br label %bb
306+
bb:
307+
%2 = phi i32 [ %0, %case1 ], [ %1, %case2 ], [ %2, %bb ]
308+
%3 = icmp eq i32 %2, 2
309+
br i1 %3, label %bb, label %loop
310+
default:
311+
ret void
312+
}
313+
)";
314+
315+
Module &M = parseModule(ModuleString);
316+
Function *F = M.getFunction("foo");
317+
FunctionSpecializer Specializer = getSpecializerFor(F);
318+
InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
319+
320+
Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
321+
322+
auto FuncIter = F->begin();
323+
for (int I = 0; I < 4; ++I)
324+
++FuncIter;
325+
326+
BasicBlock &BB = *FuncIter;
327+
328+
Instruction &Phi = BB.front();
329+
Instruction &Icmp = *++BB.begin();
330+
331+
Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor) +
332+
Specializer.getSpecializationBonus(F->getArg(1), One, Visitor) +
333+
Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
334+
EXPECT_TRUE(Bonus > 0);
335+
336+
// phi + icmp
337+
Cost Ref = getInstCost(Phi) + getInstCost(Icmp);
338+
Bonus = Visitor.getBonusFromPendingPHIs();
339+
EXPECT_EQ(Bonus, Ref);
340+
EXPECT_TRUE(Bonus > 0);
341+
}
342+

0 commit comments

Comments
 (0)