Skip to content

Commit 6472cb1

Browse files
authored
[FuncSpec] Improve estimation of select instruction. (#111176)
When propagating a constant to a select instruction we only consider the condition operand as the use. I am extending the logic to consider the true and false values too, in case the condition had been found to be constant in a previous propagation but halted.
1 parent e71ac93 commit 6472cb1

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

llvm/lib/Transforms/IPO/FunctionSpecialization.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -423,13 +423,16 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
423423
Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
424424
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
425425

426-
if (I.getCondition() != LastVisited->first)
427-
return nullptr;
428-
429-
Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue()
430-
: I.getTrueValue();
431-
Constant *C = findConstantFor(V, KnownConstants);
432-
return C;
426+
if (I.getCondition() == LastVisited->first) {
427+
Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue()
428+
: I.getTrueValue();
429+
return findConstantFor(V, KnownConstants);
430+
}
431+
if (Constant *Condition = findConstantFor(I.getCondition(), KnownConstants))
432+
if ((I.getTrueValue() == LastVisited->first && Condition->isOneValue()) ||
433+
(I.getFalseValue() == LastVisited->first && Condition->isZeroValue()))
434+
return LastVisited->second;
435+
return nullptr;
433436
}
434437

435438
Constant *InstCostVisitor::visitCastInst(CastInst &I) {

llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,34 @@ TEST_F(FunctionSpecializationTest, BranchInst) {
261261
EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
262262
}
263263

264+
TEST_F(FunctionSpecializationTest, SelectInst) {
265+
const char *ModuleString = R"(
266+
define i32 @foo(i1 %cond, i32 %a, i32 %b) {
267+
%sel = select i1 %cond, i32 %a, i32 %b
268+
ret i32 %sel
269+
}
270+
)";
271+
272+
Module &M = parseModule(ModuleString);
273+
Function *F = M.getFunction("foo");
274+
FunctionSpecializer Specializer = getSpecializerFor(F);
275+
InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
276+
277+
Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
278+
Constant *Zero = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 0);
279+
Constant *False = ConstantInt::getFalse(M.getContext());
280+
Instruction &Select = *F->front().begin();
281+
282+
Bonus Ref = getInstCost(Select);
283+
Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), False);
284+
EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
285+
Test = Visitor.getSpecializationBonus(F->getArg(1), One);
286+
EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
287+
Test = Visitor.getSpecializationBonus(F->getArg(2), Zero);
288+
EXPECT_EQ(Test, Ref);
289+
EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
290+
}
291+
264292
TEST_F(FunctionSpecializationTest, Misc) {
265293
const char *ModuleString = R"(
266294
%struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }

0 commit comments

Comments
 (0)