Skip to content

Commit 3add3ab

Browse files
authored
[GlobalOpt][FMV] Fix static resolution of calls. (#160011)
Addresses the issues found during the review of https://github.com/llvm/llvm-project/pull/150267/files#r2356936355
Currently, when collecting the users of an IFunc symbol to determine the callers, we incorrectly mix versions of different functions together with non-FMV callers, all in the same bag. That is problematic because we incorrectly deduce which features are unavailable as we iterate over the callers.
I have updated the unit tests to require a resolver function for the callers, and regenerated the resolvers, since some FMV features have been removed, making the detection bitmasks different. I've replaced the deleted FMV feature ls64 with cssc. I've added a new test to cover unrelated callers.
1 parent 8cc93c4 commit 3add3ab

File tree

2 files changed

+448
-151
lines changed

2 files changed

+448
-151
lines changed

llvm/lib/Transforms/IPO/GlobalOpt.cpp

Lines changed: 150 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,20 +2488,21 @@ DeleteDeadIFuncs(Module &M,
24882488
// Follows the use-def chain of \p V backwards until it finds a Function,
24892489
// in which case it collects in \p Versions. Return true on successful
24902490
// use-def chain traversal, false otherwise.
2491-
static bool collectVersions(TargetTransformInfo &TTI, Value *V,
2492-
SmallVectorImpl<Function *> &Versions) {
2491+
static bool
2492+
collectVersions(Value *V, SmallVectorImpl<Function *> &Versions,
2493+
function_ref<TargetTransformInfo &(Function &)> GetTTI) {
24932494
if (auto *F = dyn_cast<Function>(V)) {
2494-
if (!TTI.isMultiversionedFunction(*F))
2495+
if (!GetTTI(*F).isMultiversionedFunction(*F))
24952496
return false;
24962497
Versions.push_back(F);
24972498
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
2498-
if (!collectVersions(TTI, Sel->getTrueValue(), Versions))
2499+
if (!collectVersions(Sel->getTrueValue(), Versions, GetTTI))
24992500
return false;
2500-
if (!collectVersions(TTI, Sel->getFalseValue(), Versions))
2501+
if (!collectVersions(Sel->getFalseValue(), Versions, GetTTI))
25012502
return false;
25022503
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
25032504
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
2504-
if (!collectVersions(TTI, Phi->getIncomingValue(I), Versions))
2505+
if (!collectVersions(Phi->getIncomingValue(I), Versions, GetTTI))
25052506
return false;
25062507
} else {
25072508
// Unknown instruction type. Bail.
@@ -2510,31 +2511,43 @@ static bool collectVersions(TargetTransformInfo &TTI, Value *V,
25102511
return true;
25112512
}
25122513

2513-
// Bypass the IFunc Resolver of MultiVersioned functions when possible. To
2514-
// deduce whether the optimization is legal we need to compare the target
2515-
// features between caller and callee versions. The criteria for bypassing
2516-
// the resolver are the following:
2517-
//
2518-
// * If the callee's feature set is a subset of the caller's feature set,
2519-
// then the callee is a candidate for direct call.
2520-
//
2521-
// * Among such candidates the one of highest priority is the best match
2522-
// and it shall be picked, unless there is a version of the callee with
2523-
// higher priority than the best match which cannot be picked from a
2524-
// higher priority caller (directly or through the resolver).
2525-
//
2526-
// * For every higher priority callee version than the best match, there
2527-
// is a higher priority caller version whose feature set availability
2528-
// is implied by the callee's feature set.
2514+
// Try to statically resolve calls to versioned functions when possible. First
2515+
// we identify the function versions which are associated with an IFUNC symbol.
2516+
// We do that by examining the resolver function of the IFUNC. Once we have
2517+
// collected all the function versions, we sort them in decreasing priority
2518+
// order. This is necessary for determining the most suitable callee version
2519+
// for each caller version. We then collect all the callsites to versioned
2520+
// functions. The static resolution is performed by comparing the feature sets
2521+
// between callers and callees. Specifically:
2522+
// * Start a walk over caller and callee lists simultaneously in order of
2523+
// decreasing priority.
2524+
// * Statically resolve calls from the current caller to the current callee,
2525+
// iff the caller feature bits are a superset of the callee feature bits.
2526+
// * For FMV callers, as long as the caller feature bits are a subset of the
2527+
// callee feature bits, advance to the next callee. This effectively prevents
2528+
// considering the current callee as a candidate for static resolution by
2529+
// following callers (explanation: preceding callers would not have been
2530+
// selected in a hypothetical runtime execution).
2531+
// * Advance to the next caller.
25292532
//
2533+
// Presentation in EuroLLVM2025:
2534+
// https://www.youtube.com/watch?v=k54MFimPz-A&t=867s
25302535
static bool OptimizeNonTrivialIFuncs(
25312536
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
25322537
bool Changed = false;
25332538

2534-
// Cache containing the mask constructed from a function's target features.
2539+
// Map containing the feature bits for a given function.
25352540
DenseMap<Function *, APInt> FeatureMask;
2541+
// Map containing all the function versions corresponding to an IFunc symbol.
2542+
DenseMap<GlobalIFunc *, SmallVector<Function *>> VersionedFuncs;
2543+
// Map containing the IFunc symbol a function is a version of.
2544+
DenseMap<Function *, GlobalIFunc *> VersionOf;
2545+
// List of all the interesting IFuncs found in the module.
2546+
SmallVector<GlobalIFunc *> IFuncs;
25362547

25372548
for (GlobalIFunc &IF : M.ifuncs()) {
2549+
LLVM_DEBUG(dbgs() << "Examining IFUNC " << IF.getName() << "\n");
2550+
25382551
if (IF.isInterposable())
25392552
continue;
25402553

@@ -2545,107 +2558,147 @@ static bool OptimizeNonTrivialIFuncs(
25452558
if (Resolver->isInterposable())
25462559
continue;
25472560

2548-
TargetTransformInfo &TTI = GetTTI(*Resolver);
2549-
2550-
// Discover the callee versions.
2551-
SmallVector<Function *> Callees;
2552-
if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
2561+
SmallVector<Function *> Versions;
2562+
// Discover the versioned functions.
2563+
if (any_of(*Resolver, [&](BasicBlock &BB) {
25532564
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
2554-
if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
2565+
if (!collectVersions(Ret->getReturnValue(), Versions, GetTTI))
25552566
return true;
25562567
return false;
25572568
}))
25582569
continue;
25592570

2560-
if (Callees.empty())
2571+
if (Versions.empty())
25612572
continue;
25622573

2563-
LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
2564-
<< Resolver->getName() << "\n");
2565-
2566-
// Cache the feature mask for each callee.
2567-
for (Function *Callee : Callees) {
2568-
auto [It, Inserted] = FeatureMask.try_emplace(Callee);
2574+
for (Function *V : Versions) {
2575+
VersionOf.insert({V, &IF});
2576+
auto [It, Inserted] = FeatureMask.try_emplace(V);
25692577
if (Inserted)
2570-
It->second = TTI.getFeatureMask(*Callee);
2578+
It->second = GetTTI(*V).getFeatureMask(*V);
25712579
}
25722580

2573-
// Sort the callee versions in decreasing priority order.
2574-
sort(Callees, [&](auto *LHS, auto *RHS) {
2581+
// Sort function versions in decreasing priority order.
2582+
sort(Versions, [&](auto *LHS, auto *RHS) {
25752583
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
25762584
});
25772585

2578-
// Find the callsites and cache the feature mask for each caller.
2579-
SmallVector<Function *> Callers;
2586+
IFuncs.push_back(&IF);
2587+
VersionedFuncs.try_emplace(&IF, std::move(Versions));
2588+
}
2589+
2590+
for (GlobalIFunc *CalleeIF : IFuncs) {
2591+
SmallVector<Function *> NonFMVCallers;
2592+
DenseSet<GlobalIFunc *> CallerIFuncs;
25802593
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
2581-
for (User *U : IF.users()) {
2594+
2595+
// Find the callsites.
2596+
for (User *U : CalleeIF->users()) {
25822597
if (auto *CB = dyn_cast<CallBase>(U)) {
2583-
if (CB->getCalledOperand() == &IF) {
2598+
if (CB->getCalledOperand() == CalleeIF) {
25842599
Function *Caller = CB->getFunction();
2585-
auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
2586-
if (FeatInserted)
2587-
FeatIt->second = TTI.getFeatureMask(*Caller);
2588-
auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
2589-
if (CallInserted)
2590-
Callers.push_back(Caller);
2591-
CallIt->second.push_back(CB);
2600+
GlobalIFunc *CallerIF = nullptr;
2601+
TargetTransformInfo &TTI = GetTTI(*Caller);
2602+
bool CallerIsFMV = TTI.isMultiversionedFunction(*Caller);
2603+
// The caller is a version of a known IFunc.
2604+
if (auto It = VersionOf.find(Caller); It != VersionOf.end())
2605+
CallerIF = It->second;
2606+
else if (!CallerIsFMV && OptimizeNonFMVCallers) {
2607+
// The caller is non-FMV.
2608+
auto [It, Inserted] = FeatureMask.try_emplace(Caller);
2609+
if (Inserted)
2610+
It->second = TTI.getFeatureMask(*Caller);
2611+
} else
2612+
// The caller is none of the above, skip.
2613+
continue;
2614+
auto [It, Inserted] = CallSites.try_emplace(Caller);
2615+
if (Inserted) {
2616+
if (CallerIsFMV)
2617+
CallerIFuncs.insert(CallerIF);
2618+
else
2619+
NonFMVCallers.push_back(Caller);
2620+
}
2621+
It->second.push_back(CB);
25922622
}
25932623
}
25942624
}
25952625

2596-
// Sort the caller versions in decreasing priority order.
2597-
sort(Callers, [&](auto *LHS, auto *RHS) {
2598-
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
2599-
});
2626+
if (CallSites.empty())
2627+
continue;
26002628

2601-
auto implies = [](APInt A, APInt B) { return B.isSubsetOf(A); };
2602-
2603-
// Index to the highest priority candidate.
2604-
unsigned I = 0;
2605-
// Now try to redirect calls starting from higher priority callers.
2606-
for (Function *Caller : Callers) {
2607-
assert(I < Callees.size() && "Found callers of equal priority");
2608-
2609-
Function *Callee = Callees[I];
2610-
APInt CallerBits = FeatureMask[Caller];
2611-
APInt CalleeBits = FeatureMask[Callee];
2612-
2613-
// In the case of FMV callers, we know that all higher priority callers
2614-
// than the current one did not get selected at runtime, which helps
2615-
// reason about the callees (if they have versions that mandate presence
2616-
// of the features which we already know are unavailable on this target).
2617-
if (TTI.isMultiversionedFunction(*Caller)) {
2618-
// If the feature set of the caller implies the feature set of the
2619-
// highest priority candidate then it shall be picked. In case of
2620-
// identical sets advance the candidate index one position.
2621-
if (CallerBits == CalleeBits)
2622-
++I;
2623-
else if (!implies(CallerBits, CalleeBits)) {
2624-
// Keep advancing the candidate index as long as the caller's
2625-
// features are a subset of the current candidate's.
2626-
while (implies(CalleeBits, CallerBits)) {
2627-
if (++I == Callees.size())
2628-
break;
2629-
CalleeBits = FeatureMask[Callees[I]];
2629+
LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
2630+
<< CalleeIF->getResolverFunction()->getName() << "\n");
2631+
2632+
// The complexity of this algorithm is linear: O(NumCallers + NumCallees).
2633+
// TODO
2634+
// A limitation it has is that we are not using information about the
2635+
// current caller to deduce why an earlier caller of higher priority was
2636+
// skipped. For example let's say the current caller is aes+sve2 and a
2637+
// previous caller was mops+sve2. Knowing that sve2 is available we could
2638+
// infer that mops is unavailable. This would allow us to skip callee
2639+
// versions which depend on mops. I tried implementing this but the
2640+
// complexity was cubic :/
2641+
auto staticallyResolveCalls = [&](ArrayRef<Function *> Callers,
2642+
ArrayRef<Function *> Callees,
2643+
bool CallerIsFMV) {
2644+
// Index to the highest priority callee candidate.
2645+
unsigned I = 0;
2646+
2647+
for (Function *const &Caller : Callers) {
2648+
if (I == Callees.size())
2649+
break;
2650+
2651+
LLVM_DEBUG(dbgs() << " Examining "
2652+
<< (CallerIsFMV ? "FMV" : "regular") << " caller "
2653+
<< Caller->getName() << "\n");
2654+
2655+
Function *Callee = Callees[I];
2656+
APInt CallerBits = FeatureMask[Caller];
2657+
APInt CalleeBits = FeatureMask[Callee];
2658+
2659+
// Statically resolve calls from the current caller to the current
2660+
// callee, iff the caller feature bits are a superset of the callee
2661+
// feature bits.
2662+
if (CalleeBits.isSubsetOf(CallerBits)) {
2663+
// Not all caller versions are necessarily users of the callee IFUNC.
2664+
if (auto It = CallSites.find(Caller); It != CallSites.end()) {
2665+
for (CallBase *CS : It->second) {
2666+
LLVM_DEBUG(dbgs() << " Redirecting call " << Caller->getName()
2667+
<< " -> " << Callee->getName() << "\n");
2668+
CS->setCalledOperand(Callee);
2669+
}
2670+
Changed = true;
26302671
}
2631-
continue;
26322672
}
2633-
} else {
2634-
// We can't reason much about non-FMV callers. Just pick the highest
2635-
// priority callee if it matches, otherwise bail.
2636-
if (!OptimizeNonFMVCallers || I > 0 || !implies(CallerBits, CalleeBits))
2673+
2674+
// Nothing else to do about non-FMV callers.
2675+
if (!CallerIsFMV)
26372676
continue;
2677+
2678+
// For FMV callers, as long as the caller feature bits are a subset of
2679+
// the callee feature bits, advance to the next callee. This effectively
2680+
// prevents considering the current callee as a candidate for static
2681+
// resolution by following callers.
2682+
while (CallerBits.isSubsetOf(FeatureMask[Callees[I]]) &&
2683+
++I < Callees.size())
2684+
;
26382685
}
2639-
auto &Calls = CallSites[Caller];
2640-
for (CallBase *CS : Calls) {
2641-
LLVM_DEBUG(dbgs() << "Redirecting call " << Caller->getName() << " -> "
2642-
<< Callee->getName() << "\n");
2643-
CS->setCalledOperand(Callee);
2644-
}
2645-
Changed = true;
2686+
};
2687+
2688+
auto &Callees = VersionedFuncs[CalleeIF];
2689+
2690+
// Optimize non-FMV calls.
2691+
if (OptimizeNonFMVCallers)
2692+
staticallyResolveCalls(NonFMVCallers, Callees, /*CallerIsFMV=*/false);
2693+
2694+
// Optimize FMV calls.
2695+
for (GlobalIFunc *CallerIF : CallerIFuncs) {
2696+
auto &Callers = VersionedFuncs[CallerIF];
2697+
staticallyResolveCalls(Callers, Callees, /*CallerIsFMV=*/true);
26462698
}
2647-
if (IF.use_empty() ||
2648-
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
2699+
2700+
if (CalleeIF->use_empty() ||
2701+
all_of(CalleeIF->users(), [](User *U) { return isa<GlobalAlias>(U); }))
26492702
NumIFuncsResolved++;
26502703
}
26512704
return Changed;

0 commit comments

Comments
 (0)