@@ -2488,20 +2488,21 @@ DeleteDeadIFuncs(Module &M,
24882488// Follows the use-def chain of \p V backwards until it finds a Function,
24892489// in which case it collects it in \p Versions. Returns true on successful
24902490// use-def chain traversal, false otherwise.
2491- static bool collectVersions (TargetTransformInfo &TTI, Value *V,
2492- SmallVectorImpl<Function *> &Versions) {
2491+ static bool
2492+ collectVersions (Value *V, SmallVectorImpl<Function *> &Versions,
2493+ function_ref<TargetTransformInfo &(Function &)> GetTTI) {
24932494 if (auto *F = dyn_cast<Function>(V)) {
2494- if (!TTI .isMultiversionedFunction (*F))
2495+ if (!GetTTI (*F) .isMultiversionedFunction (*F))
24952496 return false ;
24962497 Versions.push_back (F);
24972498 } else if (auto *Sel = dyn_cast<SelectInst>(V)) {
2498- if (!collectVersions (TTI, Sel->getTrueValue (), Versions))
2499+ if (!collectVersions (Sel->getTrueValue (), Versions, GetTTI ))
24992500 return false ;
2500- if (!collectVersions (TTI, Sel->getFalseValue (), Versions))
2501+ if (!collectVersions (Sel->getFalseValue (), Versions, GetTTI ))
25012502 return false ;
25022503 } else if (auto *Phi = dyn_cast<PHINode>(V)) {
25032504 for (unsigned I = 0 , E = Phi->getNumIncomingValues (); I != E; ++I)
2504- if (!collectVersions (TTI, Phi->getIncomingValue (I), Versions))
2505+ if (!collectVersions (Phi->getIncomingValue (I), Versions, GetTTI ))
25052506 return false ;
25062507 } else {
25072508 // Unknown instruction type. Bail.
@@ -2510,31 +2511,43 @@ static bool collectVersions(TargetTransformInfo &TTI, Value *V,
25102511 return true ;
25112512}
25122513
2513- // Bypass the IFunc Resolver of MultiVersioned functions when possible. To
2514- // deduce whether the optimization is legal we need to compare the target
2515- // features between caller and callee versions. The criteria for bypassing
2516- // the resolver are the following:
2517- //
2518- // * If the callee's feature set is a subset of the caller's feature set,
2519- // then the callee is a candidate for direct call.
2520- //
2521- // * Among such candidates the one of highest priority is the best match
2522- // and it shall be picked, unless there is a version of the callee with
2523- // higher priority than the best match which cannot be picked from a
2524- // higher priority caller (directly or through the resolver).
2525- //
2526- // * For every higher priority callee version than the best match, there
2527- // is a higher priority caller version whose feature set availability
2528- // is implied by the callee's feature set.
2514+ // Try to statically resolve calls to versioned functions when possible. First
2515+ // we identify the function versions which are associated with an IFUNC symbol.
2516+ // We do that by examining the resolver function of the IFUNC. Once we have
2517+ // collected all the function versions, we sort them in decreasing priority
2518+ // order. This is necessary for determining the most suitable callee version
2519+ // for each caller version. We then collect all the callsites to versioned
2520+ // functions. The static resolution is performed by comparing the feature sets
2521+ // between callers and callees. Specifically:
2522+ // * Start a walk over caller and callee lists simultaneously in order of
2523+ // decreasing priority.
2524+ // * Statically resolve calls from the current caller to the current callee,
2525+ // iff the caller feature bits are a superset of the callee feature bits.
2526+ // * For FMV callers, as long as the caller feature bits are a subset of the
2527+ // callee feature bits, advance to the next callee. This effectively prevents
2528+ // considering the current callee as a candidate for static resolution by
2529+ // the callers that follow (explanation: the preceding callers would not
2530+ // have been selected in a hypothetical runtime execution).
2531+ // * Advance to the next caller.
25292532//
2533+ // Presentation at EuroLLVM 2025:
2534+ // https://www.youtube.com/watch?v=k54MFimPz-A&t=867s
25302535static bool OptimizeNonTrivialIFuncs (
25312536 Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
25322537 bool Changed = false ;
25332538
2534- // Cache containing the mask constructed from a function's target features .
2539+ // Map containing the feature bits for a given function .
25352540 DenseMap<Function *, APInt> FeatureMask;
2541+ // Map containing all the function versions corresponding to an IFunc symbol.
2542+ DenseMap<GlobalIFunc *, SmallVector<Function *>> VersionedFuncs;
2543+ // Map containing the IFunc symbol a function is a version of.
2544+ DenseMap<Function *, GlobalIFunc *> VersionOf;
2545+ // List of all the interesting IFuncs found in the module.
2546+ SmallVector<GlobalIFunc *> IFuncs;
25362547
25372548 for (GlobalIFunc &IF : M.ifuncs ()) {
2549+ LLVM_DEBUG (dbgs () << " Examining IFUNC " << IF.getName () << " \n " );
2550+
25382551 if (IF.isInterposable ())
25392552 continue ;
25402553
@@ -2545,107 +2558,147 @@ static bool OptimizeNonTrivialIFuncs(
25452558 if (Resolver->isInterposable ())
25462559 continue ;
25472560
2548- TargetTransformInfo &TTI = GetTTI (*Resolver);
2549-
2550- // Discover the callee versions.
2551- SmallVector<Function *> Callees;
2552- if (any_of (*Resolver, [&TTI, &Callees](BasicBlock &BB) {
2561+ SmallVector<Function *> Versions;
2562+ // Discover the versioned functions.
2563+ if (any_of (*Resolver, [&](BasicBlock &BB) {
25532564 if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator ()))
2554- if (!collectVersions (TTI, Ret->getReturnValue (), Callees ))
2565+ if (!collectVersions (Ret->getReturnValue (), Versions, GetTTI ))
25552566 return true ;
25562567 return false ;
25572568 }))
25582569 continue ;
25592570
2560- if (Callees .empty ())
2571+ if (Versions .empty ())
25612572 continue ;
25622573
2563- LLVM_DEBUG (dbgs () << " Statically resolving calls to function "
2564- << Resolver->getName () << " \n " );
2565-
2566- // Cache the feature mask for each callee.
2567- for (Function *Callee : Callees) {
2568- auto [It, Inserted] = FeatureMask.try_emplace (Callee);
2574+ for (Function *V : Versions) {
2575+ VersionOf.insert ({V, &IF});
2576+ auto [It, Inserted] = FeatureMask.try_emplace (V);
25692577 if (Inserted)
2570- It->second = TTI .getFeatureMask (*Callee );
2578+ It->second = GetTTI (*V) .getFeatureMask (*V );
25712579 }
25722580
2573- // Sort the callee versions in decreasing priority order.
2574- sort (Callees , [&](auto *LHS, auto *RHS) {
2581+ // Sort function versions in decreasing priority order.
2582+ sort (Versions , [&](auto *LHS, auto *RHS) {
25752583 return FeatureMask[LHS].ugt (FeatureMask[RHS]);
25762584 });
25772585
2578- // Find the callsites and cache the feature mask for each caller.
2579- SmallVector<Function *> Callers;
2586+ IFuncs.push_back (&IF);
2587+ VersionedFuncs.try_emplace (&IF, std::move (Versions));
2588+ }
2589+
2590+ for (GlobalIFunc *CalleeIF : IFuncs) {
2591+ SmallVector<Function *> NonFMVCallers;
2592+ DenseSet<GlobalIFunc *> CallerIFuncs;
25802593 DenseMap<Function *, SmallVector<CallBase *>> CallSites;
2581- for (User *U : IF.users ()) {
2594+
2595+ // Find the callsites.
2596+ for (User *U : CalleeIF->users ()) {
25822597 if (auto *CB = dyn_cast<CallBase>(U)) {
2583- if (CB->getCalledOperand () == &IF ) {
2598+ if (CB->getCalledOperand () == CalleeIF ) {
25842599 Function *Caller = CB->getFunction ();
2585- auto [FeatIt, FeatInserted] = FeatureMask.try_emplace (Caller);
2586- if (FeatInserted)
2587- FeatIt->second = TTI.getFeatureMask (*Caller);
2588- auto [CallIt, CallInserted] = CallSites.try_emplace (Caller);
2589- if (CallInserted)
2590- Callers.push_back (Caller);
2591- CallIt->second .push_back (CB);
2600+ GlobalIFunc *CallerIF = nullptr ;
2601+ TargetTransformInfo &TTI = GetTTI (*Caller);
2602+ bool CallerIsFMV = TTI.isMultiversionedFunction (*Caller);
2603+ // The caller is a version of a known IFunc.
2604+ if (auto It = VersionOf.find (Caller); It != VersionOf.end ())
2605+ CallerIF = It->second ;
2606+ else if (!CallerIsFMV && OptimizeNonFMVCallers) {
2607+ // The caller is non-FMV.
2608+ auto [It, Inserted] = FeatureMask.try_emplace (Caller);
2609+ if (Inserted)
2610+ It->second = TTI.getFeatureMask (*Caller);
2611+ } else
2612+ // The caller is none of the above, skip.
2613+ continue ;
2614+ auto [It, Inserted] = CallSites.try_emplace (Caller);
2615+ if (Inserted) {
2616+ if (CallerIsFMV)
2617+ CallerIFuncs.insert (CallerIF);
2618+ else
2619+ NonFMVCallers.push_back (Caller);
2620+ }
2621+ It->second .push_back (CB);
25922622 }
25932623 }
25942624 }
25952625
2596- // Sort the caller versions in decreasing priority order.
2597- sort (Callers, [&](auto *LHS, auto *RHS) {
2598- return FeatureMask[LHS].ugt (FeatureMask[RHS]);
2599- });
2626+ if (CallSites.empty ())
2627+ continue ;
26002628
2601- auto implies = [](APInt A, APInt B) { return B.isSubsetOf (A); };
2602-
2603- // Index to the highest priority candidate.
2604- unsigned I = 0 ;
2605- // Now try to redirect calls starting from higher priority callers.
2606- for (Function *Caller : Callers) {
2607- assert (I < Callees.size () && " Found callers of equal priority" );
2608-
2609- Function *Callee = Callees[I];
2610- APInt CallerBits = FeatureMask[Caller];
2611- APInt CalleeBits = FeatureMask[Callee];
2612-
2613- // In the case of FMV callers, we know that all higher priority callers
2614- // than the current one did not get selected at runtime, which helps
2615- // reason about the callees (if they have versions that mandate presence
2616- // of the features which we already know are unavailable on this target).
2617- if (TTI.isMultiversionedFunction (*Caller)) {
2618- // If the feature set of the caller implies the feature set of the
2619- // highest priority candidate then it shall be picked. In case of
2620- // identical sets advance the candidate index one position.
2621- if (CallerBits == CalleeBits)
2622- ++I;
2623- else if (!implies (CallerBits, CalleeBits)) {
2624- // Keep advancing the candidate index as long as the caller's
2625- // features are a subset of the current candidate's.
2626- while (implies (CalleeBits, CallerBits)) {
2627- if (++I == Callees.size ())
2628- break ;
2629- CalleeBits = FeatureMask[Callees[I]];
2629+ LLVM_DEBUG (dbgs () << " Statically resolving calls to function "
2630+ << CalleeIF->getResolverFunction ()->getName () << " \n " );
2631+
2632+ // The complexity of this algorithm is linear: O(NumCallers + NumCallees).
2633+ // TODO:
2634+ // A limitation is that we do not use information about the
2635+ // current caller to deduce why an earlier caller of higher priority was
2636+ // skipped. For example, let's say the current caller is aes+sve2 and a
2637+ // previous caller was mops+sve2. Knowing that sve2 is available, we could
2638+ // infer that mops is unavailable. This would allow us to skip callee
2639+ // versions which depend on mops. An attempt at implementing this had
2640+ // cubic complexity.
2641+ auto staticallyResolveCalls = [&](ArrayRef<Function *> Callers,
2642+ ArrayRef<Function *> Callees,
2643+ bool CallerIsFMV) {
2644+ // Index of the highest priority callee candidate.
2645+ unsigned I = 0 ;
2646+
2647+ for (Function *const &Caller : Callers) {
2648+ if (I == Callees.size ())
2649+ break ;
2650+
2651+ LLVM_DEBUG (dbgs () << " Examining "
2652+ << (CallerIsFMV ? " FMV" : " regular" ) << " caller "
2653+ << Caller->getName () << " \n " );
2654+
2655+ Function *Callee = Callees[I];
2656+ APInt CallerBits = FeatureMask[Caller];
2657+ APInt CalleeBits = FeatureMask[Callee];
2658+
2659+ // Statically resolve calls from the current caller to the current
2660+ // callee, iff the caller feature bits are a superset of the callee
2661+ // feature bits.
2662+ if (CalleeBits.isSubsetOf (CallerBits)) {
2663+ // Not all caller versions are necessarily users of the callee IFUNC.
2664+ if (auto It = CallSites.find (Caller); It != CallSites.end ()) {
2665+ for (CallBase *CS : It->second ) {
2666+ LLVM_DEBUG (dbgs () << " Redirecting call " << Caller->getName ()
2667+ << " -> " << Callee->getName () << " \n " );
2668+ CS->setCalledOperand (Callee);
2669+ }
2670+ Changed = true ;
26302671 }
2631- continue ;
26322672 }
2633- } else {
2634- // We can't reason much about non-FMV callers. Just pick the highest
2635- // priority callee if it matches, otherwise bail.
2636- if (!OptimizeNonFMVCallers || I > 0 || !implies (CallerBits, CalleeBits))
2673+
2674+ // Nothing else to do about non-FMV callers.
2675+ if (!CallerIsFMV)
26372676 continue ;
2677+
2678+ // For FMV callers, as long as the caller feature bits are a subset of
2679+ // the callee feature bits, advance to the next callee. This effectively
2680+ // prevents considering the current callee as a candidate for static
2681+ // resolution by the callers that follow.
2682+ while (CallerBits.isSubsetOf (FeatureMask[Callees[I]]) &&
2683+ ++I < Callees.size ())
2684+ ;
26382685 }
2639- auto &Calls = CallSites[Caller];
2640- for (CallBase *CS : Calls) {
2641- LLVM_DEBUG (dbgs () << " Redirecting call " << Caller->getName () << " -> "
2642- << Callee->getName () << " \n " );
2643- CS->setCalledOperand (Callee);
2644- }
2645- Changed = true ;
2686+ };
2687+
2688+ auto &Callees = VersionedFuncs[CalleeIF];
2689+
2690+ // Optimize non-FMV calls.
2691+ if (OptimizeNonFMVCallers)
2692+ staticallyResolveCalls (NonFMVCallers, Callees, /* CallerIsFMV=*/ false );
2693+
2694+ // Optimize FMV calls.
2695+ for (GlobalIFunc *CallerIF : CallerIFuncs) {
2696+ auto &Callers = VersionedFuncs[CallerIF];
2697+ staticallyResolveCalls (Callers, Callees, /* CallerIsFMV=*/ true );
26462698 }
2647- if (IF.use_empty () ||
2648- all_of (IF.users (), [](User *U) { return isa<GlobalAlias>(U); }))
2699+
2700+ if (CalleeIF->use_empty () ||
2701+ all_of (CalleeIF->users (), [](User *U) { return isa<GlobalAlias>(U); }))
26492702 NumIFuncsResolved++;
26502703 }
26512704 return Changed;
0 commit comments