@@ -2482,20 +2482,21 @@ DeleteDeadIFuncs(Module &M,
24822482// Follows the use-def chain of \p V backwards until it finds a Function,
24832483// in which case it collects in \p Versions. Return true on successful
24842484// use-def chain traversal, false otherwise.
2485- static bool collectVersions (TargetTransformInfo &TTI, Value *V,
2486- SmallVectorImpl<Function *> &Versions) {
2485+ static bool
2486+ collectVersions (Value *V, SmallVectorImpl<Function *> &Versions,
2487+ function_ref<TargetTransformInfo &(Function &)> GetTTI) {
24872488 if (auto *F = dyn_cast<Function>(V)) {
2488- if (!TTI .isMultiversionedFunction (*F))
2489+ if (!GetTTI (*F) .isMultiversionedFunction (*F))
24892490 return false ;
24902491 Versions.push_back (F);
24912492 } else if (auto *Sel = dyn_cast<SelectInst>(V)) {
2492- if (!collectVersions (TTI, Sel->getTrueValue (), Versions))
2493+ if (!collectVersions (Sel->getTrueValue (), Versions, GetTTI ))
24932494 return false ;
2494- if (!collectVersions (TTI, Sel->getFalseValue (), Versions))
2495+ if (!collectVersions (Sel->getFalseValue (), Versions, GetTTI ))
24952496 return false ;
24962497 } else if (auto *Phi = dyn_cast<PHINode>(V)) {
24972498 for (unsigned I = 0 , E = Phi->getNumIncomingValues (); I != E; ++I)
2498- if (!collectVersions (TTI, Phi->getIncomingValue (I), Versions))
2499+ if (!collectVersions (Phi->getIncomingValue (I), Versions, GetTTI ))
24992500 return false ;
25002501 } else {
25012502 // Unknown instruction type. Bail.
@@ -2525,8 +2526,14 @@ static bool OptimizeNonTrivialIFuncs(
25252526 Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
25262527 bool Changed = false ;
25272528
2528- // Cache containing the mask constructed from a function's target features .
2529+ // Map containing the feature bits for a given function .
25292530 DenseMap<Function *, APInt> FeatureMask;
2531+ // Map containing all the versions corresponding to an IFunc symbol.
2532+ DenseMap<GlobalIFunc *, SmallVector<Function *>> VersionedFuncs;
2533+ // Map containing the IFunc symbol a function is version of.
2534+ DenseMap<Function *, GlobalIFunc *> VersionOf;
2535+ // List of all the interesting IFuncs found in the module.
2536+ SmallVector<GlobalIFunc *> IFuncs;
25302537
25312538 for (GlobalIFunc &IF : M.ifuncs ()) {
25322539 if (IF.isInterposable ())
@@ -2539,107 +2546,140 @@ static bool OptimizeNonTrivialIFuncs(
25392546 if (Resolver->isInterposable ())
25402547 continue ;
25412548
2542- TargetTransformInfo &TTI = GetTTI (*Resolver);
2543-
2544- // Discover the callee versions.
2545- SmallVector<Function *> Callees;
2546- if (any_of (*Resolver, [&TTI, &Callees](BasicBlock &BB) {
2549+ SmallVector<Function *> Versions;
2550+ // Discover the versioned functions.
2551+ if (any_of (*Resolver, [&](BasicBlock &BB) {
25472552 if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator ()))
2548- if (!collectVersions (TTI, Ret->getReturnValue (), Callees ))
2553+ if (!collectVersions (Ret->getReturnValue (), Versions, GetTTI ))
25492554 return true ;
25502555 return false ;
25512556 }))
25522557 continue ;
25532558
2554- if (Callees .empty ())
2559+ if (Versions .empty ())
25552560 continue ;
25562561
2557- LLVM_DEBUG (dbgs () << " Statically resolving calls to function "
2558- << Resolver->getName () << " \n " );
2559-
2560- // Cache the feature mask for each callee.
2561- for (Function *Callee : Callees) {
2562- auto [It, Inserted] = FeatureMask.try_emplace (Callee);
2562+ for (Function *V : Versions) {
2563+ VersionOf.insert ({V, &IF});
2564+ auto [It, Inserted] = FeatureMask.try_emplace (V);
25632565 if (Inserted)
2564- It->second = TTI .getFeatureMask (*Callee );
2566+ It->second = GetTTI (*V) .getFeatureMask (*V );
25652567 }
25662568
2567- // Sort the callee versions in decreasing priority order.
2568- sort (Callees , [&](auto *LHS, auto *RHS) {
2569+ // Sort function versions in decreasing priority order.
2570+ sort (Versions , [&](auto *LHS, auto *RHS) {
25692571 return FeatureMask[LHS].ugt (FeatureMask[RHS]);
25702572 });
25712573
2572- // Find the callsites and cache the feature mask for each caller.
2573- SmallVector<Function *> Callers;
2574+ IFuncs.push_back (&IF);
2575+ VersionedFuncs.try_emplace (&IF, std::move (Versions));
2576+ }
2577+
2578+ for (GlobalIFunc *CalleeIF : IFuncs) {
2579+ SmallVector<Function *> NonFMVCallers;
2580+ SmallVector<GlobalIFunc *> CallerIFuncs;
25742581 DenseMap<Function *, SmallVector<CallBase *>> CallSites;
2575- for (User *U : IF.users ()) {
2582+
2583+ // Find the callsites.
2584+ for (User *U : CalleeIF->users ()) {
25762585 if (auto *CB = dyn_cast<CallBase>(U)) {
2577- if (CB->getCalledOperand () == &IF ) {
2586+ if (CB->getCalledOperand () == CalleeIF ) {
25782587 Function *Caller = CB->getFunction ();
2579- auto [FeatIt, FeatInserted] = FeatureMask.try_emplace (Caller);
2580- if (FeatInserted)
2581- FeatIt->second = TTI.getFeatureMask (*Caller);
2582- auto [CallIt, CallInserted] = CallSites.try_emplace (Caller);
2583- if (CallInserted)
2584- Callers.push_back (Caller);
2585- CallIt->second .push_back (CB);
2588+ GlobalIFunc *CallerIFunc = nullptr ;
2589+ TargetTransformInfo &TTI = GetTTI (*Caller);
2590+ bool CallerIsFMV = TTI.isMultiversionedFunction (*Caller);
2591+ // The caller is a version of a known IFunc.
2592+ if (auto It = VersionOf.find (Caller); It != VersionOf.end ())
2593+ CallerIFunc = It->second ;
2594+ else if (!CallerIsFMV && OptimizeNonFMVCallers) {
2595+ // The caller is non-FMV.
2596+ auto [It, Inserted] = FeatureMask.try_emplace (Caller);
2597+ if (Inserted)
2598+ It->second = TTI.getFeatureMask (*Caller);
2599+ } else
2600+ // The caller is none of the above, skip.
2601+ continue ;
2602+ auto [It, Inserted] = CallSites.try_emplace (Caller);
2603+ if (Inserted) {
2604+ if (CallerIsFMV)
2605+ CallerIFuncs.push_back (CallerIFunc);
2606+ else
2607+ NonFMVCallers.push_back (Caller);
2608+ }
2609+ It->second .push_back (CB);
25862610 }
25872611 }
25882612 }
25892613
2590- // Sort the caller versions in decreasing priority order.
2591- sort (Callers, [&](auto *LHS, auto *RHS) {
2592- return FeatureMask[LHS].ugt (FeatureMask[RHS]);
2593- });
2594-
2595- auto implies = [](APInt A, APInt B) { return B.isSubsetOf (A); };
2614+ LLVM_DEBUG (dbgs () << " Statically resolving calls to function "
2615+ << CalleeIF->getResolverFunction ()->getName () << " \n " );
25962616
2597- // Index to the highest priority candidate.
2598- unsigned I = 0 ;
2599- // Now try to redirect calls starting from higher priority callers.
2600- for (Function *Caller : Callers) {
2601- assert (I < Callees.size () && " Found callers of equal priority" );
2617+ auto redirectCalls = [&](SmallVectorImpl<Function *> &Callers,
2618+ SmallVectorImpl<Function *> &Callees) {
2619+ // Index to the current callee candidate.
2620+ unsigned I = 0 ;
26022621
2603- Function *Callee = Callees[I];
2604- APInt CallerBits = FeatureMask[Caller];
2605- APInt CalleeBits = FeatureMask[Callee];
2622+ // Try to redirect calls starting from higher priority callers.
2623+ for (Function *Caller : Callers) {
2624+ if (I == Callees.size ())
2625+ break ;
26062626
2607- // In the case of FMV callers, we know that all higher priority callers
2608- // than the current one did not get selected at runtime, which helps
2609- // reason about the callees (if they have versions that mandate presence
2610- // of the features which we already know are unavailable on this target).
2611- if (TTI.isMultiversionedFunction (*Caller)) {
2627+ bool CallerIsFMV = GetTTI (*Caller).isMultiversionedFunction (*Caller);
2628+ // In the case of FMV callers, we know that all higher priority callers
2629+ // than the current one did not get selected at runtime, which helps
2630+ // reason about the callees (if they have versions that mandate presence
2631+ // of the features which we already know are unavailable on this
2632+ // target).
2633+ if (!CallerIsFMV)
2634+ // We can't reason much about non-FMV callers. Just pick the highest
2635+ // priority callee if it matches, otherwise bail.
2636+ assert (I == 0 && " Should only select the highest priority candidate" );
2637+
2638+ Function *Callee = Callees[I];
2639+ APInt CallerBits = FeatureMask[Caller];
2640+ APInt CalleeBits = FeatureMask[Callee];
26122641 // If the feature set of the caller implies the feature set of the
2613- // highest priority candidate then it shall be picked. In case of
2614- // identical sets advance the candidate index one position.
2615- if (CallerBits == CalleeBits)
2616- ++I;
2617- else if (!implies (CallerBits, CalleeBits)) {
2618- // Keep advancing the candidate index as long as the caller's
2619- // features are a subset of the current candidate's.
2620- while (implies (CalleeBits, CallerBits)) {
2642+ // highest priority candidate then it shall be picked.
2643+ if (CalleeBits.isSubsetOf (CallerBits)) {
2644+ // If there are no records of call sites for this particular function
2645+ // version, then it is not actually a caller, in which case skip.
2646+ if (auto It = CallSites.find (Caller); It != CallSites.end ()) {
2647+ for (CallBase *CS : It->second ) {
2648+ LLVM_DEBUG (dbgs () << " Redirecting call " << Caller->getName ()
2649+ << " -> " << Callee->getName () << " \n " );
2650+ CS->setCalledOperand (Callee);
2651+ }
2652+ Changed = true ;
2653+ }
2654+ }
2655+ // Keep advancing the candidate index as long as the caller's
2656+ // features are a subset of the current candidate's.
2657+ if (CallerIsFMV) {
2658+ while (CallerBits.isSubsetOf (CalleeBits)) {
26212659 if (++I == Callees.size ())
26222660 break ;
26232661 CalleeBits = FeatureMask[Callees[I]];
26242662 }
2625- continue ;
26262663 }
2627- } else {
2628- // We can't reason much about non-FMV callers. Just pick the highest
2629- // priority callee if it matches, otherwise bail.
2630- if (!OptimizeNonFMVCallers || I > 0 || !implies (CallerBits, CalleeBits))
2631- continue ;
26322664 }
2633- auto &Calls = CallSites[Caller];
2634- for (CallBase *CS : Calls) {
2635- LLVM_DEBUG (dbgs () << " Redirecting call " << Caller->getName () << " -> "
2636- << Callee->getName () << " \n " );
2637- CS->setCalledOperand (Callee);
2665+ };
2666+
2667+ auto &Callees = VersionedFuncs[CalleeIF];
2668+
2669+ // Optimize non-FMV calls.
2670+ if (!NonFMVCallers.empty () && OptimizeNonFMVCallers)
2671+ redirectCalls (NonFMVCallers, Callees);
2672+
2673+ // Optimize FMV calls.
2674+ if (!CallerIFuncs.empty ()) {
2675+ for (GlobalIFunc *CallerIF : CallerIFuncs) {
2676+ auto &Callers = VersionedFuncs[CallerIF];
2677+ redirectCalls (Callers, Callees);
26382678 }
2639- Changed = true ;
26402679 }
2641- if (IF.use_empty () ||
2642- all_of (IF.users (), [](User *U) { return isa<GlobalAlias>(U); }))
2680+
2681+ if (CalleeIF->use_empty () ||
2682+ all_of (CalleeIF->users (), [](User *U) { return isa<GlobalAlias>(U); }))
26432683 NumIFuncsResolved++;
26442684 }
26452685 return Changed;
0 commit comments