Skip to content

Commit 4bac313

Browse files
committed
[GlobalOpt][FMV] Fix static resolution of calls.
Addresses the issues found during the review of https://github.com/llvm/llvm-project/pull/150267/files#r2356936355. Currently, when collecting the users of an IFunc symbol to determine the callers, we incorrectly mix versions of different functions together with non-FMV callers, all in the same bag. That is problematic because we then incorrectly deduce which features are unavailable as we iterate over the callers. I have updated the unit tests to require a resolver function for the callers, and regenerated the resolvers since some FMV features have been removed, making the detection bitmasks different. I've replaced the deleted FMV feature ls64 with cssc. I've added a new test to cover unrelated callers.
1 parent 771c94c commit 4bac313

File tree

2 files changed

+416
-129
lines changed

2 files changed

+416
-129
lines changed

llvm/lib/Transforms/IPO/GlobalOpt.cpp

Lines changed: 115 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -2482,20 +2482,21 @@ DeleteDeadIFuncs(Module &M,
24822482
// Follows the use-def chain of \p V backwards until it finds a Function,
24832483
// in which case it collects in \p Versions. Return true on successful
24842484
// use-def chain traversal, false otherwise.
2485-
static bool collectVersions(TargetTransformInfo &TTI, Value *V,
2486-
SmallVectorImpl<Function *> &Versions) {
2485+
static bool
2486+
collectVersions(Value *V, SmallVectorImpl<Function *> &Versions,
2487+
function_ref<TargetTransformInfo &(Function &)> GetTTI) {
24872488
if (auto *F = dyn_cast<Function>(V)) {
2488-
if (!TTI.isMultiversionedFunction(*F))
2489+
if (!GetTTI(*F).isMultiversionedFunction(*F))
24892490
return false;
24902491
Versions.push_back(F);
24912492
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
2492-
if (!collectVersions(TTI, Sel->getTrueValue(), Versions))
2493+
if (!collectVersions(Sel->getTrueValue(), Versions, GetTTI))
24932494
return false;
2494-
if (!collectVersions(TTI, Sel->getFalseValue(), Versions))
2495+
if (!collectVersions(Sel->getFalseValue(), Versions, GetTTI))
24952496
return false;
24962497
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
24972498
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
2498-
if (!collectVersions(TTI, Phi->getIncomingValue(I), Versions))
2499+
if (!collectVersions(Phi->getIncomingValue(I), Versions, GetTTI))
24992500
return false;
25002501
} else {
25012502
// Unknown instruction type. Bail.
@@ -2525,8 +2526,14 @@ static bool OptimizeNonTrivialIFuncs(
25252526
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
25262527
bool Changed = false;
25272528

2528-
// Cache containing the mask constructed from a function's target features.
2529+
// Map containing the feature bits for a given function.
25292530
DenseMap<Function *, APInt> FeatureMask;
2531+
// Map containing all the versions corresponding to an IFunc symbol.
2532+
DenseMap<GlobalIFunc *, SmallVector<Function *>> VersionedFuncs;
2533+
// Map containing the IFunc symbol a function is a version of.
2534+
DenseMap<Function *, GlobalIFunc *> VersionOf;
2535+
// List of all the interesting IFuncs found in the module.
2536+
SmallVector<GlobalIFunc *> IFuncs;
25302537

25312538
for (GlobalIFunc &IF : M.ifuncs()) {
25322539
if (IF.isInterposable())
@@ -2539,107 +2546,140 @@ static bool OptimizeNonTrivialIFuncs(
25392546
if (Resolver->isInterposable())
25402547
continue;
25412548

2542-
TargetTransformInfo &TTI = GetTTI(*Resolver);
2543-
2544-
// Discover the callee versions.
2545-
SmallVector<Function *> Callees;
2546-
if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
2549+
SmallVector<Function *> Versions;
2550+
// Discover the versioned functions.
2551+
if (any_of(*Resolver, [&](BasicBlock &BB) {
25472552
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
2548-
if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
2553+
if (!collectVersions(Ret->getReturnValue(), Versions, GetTTI))
25492554
return true;
25502555
return false;
25512556
}))
25522557
continue;
25532558

2554-
if (Callees.empty())
2559+
if (Versions.empty())
25552560
continue;
25562561

2557-
LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
2558-
<< Resolver->getName() << "\n");
2559-
2560-
// Cache the feature mask for each callee.
2561-
for (Function *Callee : Callees) {
2562-
auto [It, Inserted] = FeatureMask.try_emplace(Callee);
2562+
for (Function *V : Versions) {
2563+
VersionOf.insert({V, &IF});
2564+
auto [It, Inserted] = FeatureMask.try_emplace(V);
25632565
if (Inserted)
2564-
It->second = TTI.getFeatureMask(*Callee);
2566+
It->second = GetTTI(*V).getFeatureMask(*V);
25652567
}
25662568

2567-
// Sort the callee versions in decreasing priority order.
2568-
sort(Callees, [&](auto *LHS, auto *RHS) {
2569+
// Sort function versions in decreasing priority order.
2570+
sort(Versions, [&](auto *LHS, auto *RHS) {
25692571
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
25702572
});
25712573

2572-
// Find the callsites and cache the feature mask for each caller.
2573-
SmallVector<Function *> Callers;
2574+
IFuncs.push_back(&IF);
2575+
VersionedFuncs.try_emplace(&IF, std::move(Versions));
2576+
}
2577+
2578+
for (GlobalIFunc *CalleeIF : IFuncs) {
2579+
SmallVector<Function *> NonFMVCallers;
2580+
SmallVector<GlobalIFunc *> CallerIFuncs;
25742581
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
2575-
for (User *U : IF.users()) {
2582+
2583+
// Find the callsites.
2584+
for (User *U : CalleeIF->users()) {
25762585
if (auto *CB = dyn_cast<CallBase>(U)) {
2577-
if (CB->getCalledOperand() == &IF) {
2586+
if (CB->getCalledOperand() == CalleeIF) {
25782587
Function *Caller = CB->getFunction();
2579-
auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
2580-
if (FeatInserted)
2581-
FeatIt->second = TTI.getFeatureMask(*Caller);
2582-
auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
2583-
if (CallInserted)
2584-
Callers.push_back(Caller);
2585-
CallIt->second.push_back(CB);
2588+
GlobalIFunc *CallerIFunc = nullptr;
2589+
TargetTransformInfo &TTI = GetTTI(*Caller);
2590+
bool CallerIsFMV = TTI.isMultiversionedFunction(*Caller);
2591+
// The caller is a version of a known IFunc.
2592+
if (auto It = VersionOf.find(Caller); It != VersionOf.end())
2593+
CallerIFunc = It->second;
2594+
else if (!CallerIsFMV && OptimizeNonFMVCallers) {
2595+
// The caller is non-FMV.
2596+
auto [It, Inserted] = FeatureMask.try_emplace(Caller);
2597+
if (Inserted)
2598+
It->second = TTI.getFeatureMask(*Caller);
2599+
} else
2600+
// The caller is none of the above, skip.
2601+
continue;
2602+
auto [It, Inserted] = CallSites.try_emplace(Caller);
2603+
if (Inserted) {
2604+
if (CallerIsFMV)
2605+
CallerIFuncs.push_back(CallerIFunc);
2606+
else
2607+
NonFMVCallers.push_back(Caller);
2608+
}
2609+
It->second.push_back(CB);
25862610
}
25872611
}
25882612
}
25892613

2590-
// Sort the caller versions in decreasing priority order.
2591-
sort(Callers, [&](auto *LHS, auto *RHS) {
2592-
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
2593-
});
2594-
2595-
auto implies = [](APInt A, APInt B) { return B.isSubsetOf(A); };
2614+
LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
2615+
<< CalleeIF->getResolverFunction()->getName() << "\n");
25962616

2597-
// Index to the highest priority candidate.
2598-
unsigned I = 0;
2599-
// Now try to redirect calls starting from higher priority callers.
2600-
for (Function *Caller : Callers) {
2601-
assert(I < Callees.size() && "Found callers of equal priority");
2617+
auto redirectCalls = [&](SmallVectorImpl<Function *> &Callers,
2618+
SmallVectorImpl<Function *> &Callees) {
2619+
// Index to the current callee candidate.
2620+
unsigned I = 0;
26022621

2603-
Function *Callee = Callees[I];
2604-
APInt CallerBits = FeatureMask[Caller];
2605-
APInt CalleeBits = FeatureMask[Callee];
2622+
// Try to redirect calls starting from higher priority callers.
2623+
for (Function *Caller : Callers) {
2624+
if (I == Callees.size())
2625+
break;
26062626

2607-
// In the case of FMV callers, we know that all higher priority callers
2608-
// than the current one did not get selected at runtime, which helps
2609-
// reason about the callees (if they have versions that mandate presence
2610-
// of the features which we already know are unavailable on this target).
2611-
if (TTI.isMultiversionedFunction(*Caller)) {
2627+
bool CallerIsFMV = GetTTI(*Caller).isMultiversionedFunction(*Caller);
2628+
// In the case of FMV callers, we know that all higher priority callers
2629+
// than the current one did not get selected at runtime, which helps
2630+
// reason about the callees (if they have versions that mandate presence
2631+
// of the features which we already know are unavailable on this
2632+
// target).
2633+
if (!CallerIsFMV)
2634+
// We can't reason much about non-FMV callers. Just pick the highest
2635+
// priority callee if it matches, otherwise bail.
2636+
assert(I == 0 && "Should only select the highest priority candidate");
2637+
2638+
Function *Callee = Callees[I];
2639+
APInt CallerBits = FeatureMask[Caller];
2640+
APInt CalleeBits = FeatureMask[Callee];
26122641
// If the feature set of the caller implies the feature set of the
2613-
// highest priority candidate then it shall be picked. In case of
2614-
// identical sets advance the candidate index one position.
2615-
if (CallerBits == CalleeBits)
2616-
++I;
2617-
else if (!implies(CallerBits, CalleeBits)) {
2618-
// Keep advancing the candidate index as long as the caller's
2619-
// features are a subset of the current candidate's.
2620-
while (implies(CalleeBits, CallerBits)) {
2642+
// highest priority candidate then it shall be picked.
2643+
if (CalleeBits.isSubsetOf(CallerBits)) {
2644+
// If there are no records of call sites for this particular function
2645+
// version, then it is not actually a caller, in which case skip.
2646+
if (auto It = CallSites.find(Caller); It != CallSites.end()) {
2647+
for (CallBase *CS : It->second) {
2648+
LLVM_DEBUG(dbgs() << "Redirecting call " << Caller->getName()
2649+
<< " -> " << Callee->getName() << "\n");
2650+
CS->setCalledOperand(Callee);
2651+
}
2652+
Changed = true;
2653+
}
2654+
}
2655+
// Keep advancing the candidate index as long as the caller's
2656+
// features are a subset of the current candidate's.
2657+
if (CallerIsFMV) {
2658+
while (CallerBits.isSubsetOf(CalleeBits)) {
26212659
if (++I == Callees.size())
26222660
break;
26232661
CalleeBits = FeatureMask[Callees[I]];
26242662
}
2625-
continue;
26262663
}
2627-
} else {
2628-
// We can't reason much about non-FMV callers. Just pick the highest
2629-
// priority callee if it matches, otherwise bail.
2630-
if (!OptimizeNonFMVCallers || I > 0 || !implies(CallerBits, CalleeBits))
2631-
continue;
26322664
}
2633-
auto &Calls = CallSites[Caller];
2634-
for (CallBase *CS : Calls) {
2635-
LLVM_DEBUG(dbgs() << "Redirecting call " << Caller->getName() << " -> "
2636-
<< Callee->getName() << "\n");
2637-
CS->setCalledOperand(Callee);
2665+
};
2666+
2667+
auto &Callees = VersionedFuncs[CalleeIF];
2668+
2669+
// Optimize non-FMV calls.
2670+
if (!NonFMVCallers.empty() && OptimizeNonFMVCallers)
2671+
redirectCalls(NonFMVCallers, Callees);
2672+
2673+
// Optimize FMV calls.
2674+
if (!CallerIFuncs.empty()) {
2675+
for (GlobalIFunc *CallerIF : CallerIFuncs) {
2676+
auto &Callers = VersionedFuncs[CallerIF];
2677+
redirectCalls(Callers, Callees);
26382678
}
2639-
Changed = true;
26402679
}
2641-
if (IF.use_empty() ||
2642-
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
2680+
2681+
if (CalleeIF->use_empty() ||
2682+
all_of(CalleeIF->users(), [](User *U) { return isa<GlobalAlias>(U); }))
26432683
NumIFuncsResolved++;
26442684
}
26452685
return Changed;

0 commit comments

Comments
 (0)