@@ -89,7 +89,7 @@ STATISTIC(NumAliasesRemoved, "Number of global aliases eliminated");
8989STATISTIC (NumCXXDtorsRemoved, " Number of global C++ destructors removed" );
9090STATISTIC (NumInternalFunc, " Number of internal functions" );
9191STATISTIC (NumColdCC, " Number of functions marked coldcc" );
92- STATISTIC (NumIFuncsResolved, " Number of statically resolved IFuncs" );
92+ STATISTIC (NumIFuncsResolved, " Number of resolved IFuncs" );
9393STATISTIC (NumIFuncsDeleted, " Number of IFuncs removed" );
9494
9595static cl::opt<bool >
@@ -2462,6 +2462,228 @@ DeleteDeadIFuncs(Module &M,
24622462 return Changed;
24632463}
24642464
2465+ static Function *foldResolverForCallSite (CallBase *CS, uint64_t Priority,
2466+ TargetTransformInfo &TTI) {
2467+ // Look for the instruction which feeds the feature mask to the users.
2468+ auto findRoot = [&TTI](Function *F) -> Instruction * {
2469+ for (Instruction &I : F->getEntryBlock ())
2470+ if (auto *Load = dyn_cast<LoadInst>(&I))
2471+ if (Load->getPointerOperand () == TTI.getCPUFeatures (*F->getParent ()))
2472+ return Load;
2473+ return nullptr ;
2474+ };
2475+
2476+ auto *IF = cast<GlobalIFunc>(CS->getCalledOperand ());
2477+ Instruction *Root = findRoot (IF->getResolverFunction ());
2478+ // There is no such instruction. Bail.
2479+ if (!Root)
2480+ return nullptr ;
2481+
2482+ // Create a constant mask to use as seed for the constant propagation.
2483+ Constant *Seed = Constant::getIntegerValue (
2484+ Root->getType (), APInt (Root->getType ()->getIntegerBitWidth (), Priority));
2485+
2486+ auto DL = CS->getModule ()->getDataLayout ();
2487+
2488+ // Recursively propagate on single use chains.
2489+ std::function<Constant *(Instruction *, Instruction *, Constant *,
2490+ BasicBlock *)>
2491+ constFoldInst = [&](Instruction *I, Instruction *Use, Constant *C,
2492+ BasicBlock *Pred) -> Constant * {
2493+ // Base case.
2494+ if (auto *Ret = dyn_cast<ReturnInst>(I))
2495+ if (Ret->getReturnValue () == Use)
2496+ return C;
2497+
2498+ // Minimal set of instruction types to handle.
2499+ if (auto *BinOp = dyn_cast<BinaryOperator>(I)) {
2500+ bool Swap = BinOp->getOperand (1 ) == Use;
2501+ if (auto *Other = dyn_cast<Constant>(BinOp->getOperand (Swap ? 0 : 1 )))
2502+ C = Swap ? ConstantFoldBinaryInstruction (BinOp->getOpcode (), Other, C)
2503+ : ConstantFoldBinaryInstruction (BinOp->getOpcode (), C, Other);
2504+ } else if (auto *Cmp = dyn_cast<CmpInst>(I)) {
2505+ bool Swap = Cmp->getOperand (1 ) == Use;
2506+ if (auto *Other = dyn_cast<Constant>(Cmp->getOperand (Swap ? 0 : 1 )))
2507+ C = Swap ? ConstantFoldCompareInstOperands (Cmp->getPredicate (), Other,
2508+ C, DL)
2509+ : ConstantFoldCompareInstOperands (Cmp->getPredicate (), C,
2510+ Other, DL);
2511+ } else if (auto *Sel = dyn_cast<SelectInst>(I)) {
2512+ if (Sel->getCondition () == Use)
2513+ C = dyn_cast<Constant>(C->isZeroValue () ? Sel->getFalseValue ()
2514+ : Sel->getTrueValue ());
2515+ } else if (auto *Phi = dyn_cast<PHINode>(I)) {
2516+ if (Pred)
2517+ C = dyn_cast<Constant>(Phi->getIncomingValueForBlock (Pred));
2518+ } else if (auto *Br = dyn_cast<BranchInst>(I)) {
2519+ if (Br->getCondition () == Use) {
2520+ BasicBlock *BB = Br->getSuccessor (C->isZeroValue ());
2521+ return constFoldInst (&BB->front (), Root, Seed, Br->getParent ());
2522+ }
2523+ } else {
2524+ // Don't know how to handle. Bail.
2525+ return nullptr ;
2526+ }
2527+
2528+ // Folding succeeded. Continue.
2529+ if (C && I->hasOneUse ())
2530+ if (auto *UI = dyn_cast<Instruction>(I->user_back ()))
2531+ return constFoldInst (UI, I, C, nullptr );
2532+
2533+ return nullptr ;
2534+ };
2535+
2536+ // Collect all users in the entry block ordered by proximity. The rest of
2537+ // them can be discovered later. Unfortunately we cannot simply traverse
2538+ // the Root's 'users()' as their order is not the same as execution order.
2539+ unsigned NUsersLeft = std::distance (Root->user_begin (), Root->user_end ());
2540+ SmallVector<Instruction *> Users;
2541+ for (Instruction &I : *Root->getParent ()) {
2542+ if (any_of (I.operands (), [Root](auto &Op) { return Op == Root; })) {
2543+ Users.push_back (&I);
2544+ if (--NUsersLeft == 0 )
2545+ break ;
2546+ }
2547+ }
2548+
2549+ // Return as soon as we find a foldable user. It has the highest priority.
2550+ for (Instruction *I : Users) {
2551+ Constant *C = constFoldInst (I, Root, Seed, nullptr );
2552+ if (C)
2553+ return cast<Function>(C);
2554+ }
2555+
2556+ return nullptr ;
2557+ }
2558+
2559+ // Bypass the IFunc Resolver of MultiVersioned functions when possible. To
2560+ // deduce whether the optimization is legal we need to compare the target
2561+ // features between caller and callee versions. The criteria for bypassing
2562+ // the resolver are the following:
2563+ //
2564+ // * If the callee's feature set is a subset of the caller's feature set,
2565+ // then the callee is a candidate for direct call.
2566+ //
2567+ // * Among such candidates the one of highest priority is the best match
2568+ // and it shall be picked, unless there is a version of the callee with
2569+ // higher priority than the best match which cannot be picked because
2570+ // there is no corresponding caller for whom it would have been the best
2571+ // match.
2572+ //
2573+ static bool OptimizeNonTrivialIFuncs (
2574+ Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
2575+ bool Changed = false ;
2576+
2577+ std::function<void (Value *, SmallVectorImpl<Function *> &)> visitValue =
2578+ [&](Value *V, SmallVectorImpl<Function *> &FuncVersions) {
2579+ if (auto *Func = dyn_cast<Function>(V)) {
2580+ FuncVersions.push_back (Func);
2581+ } else if (auto *Sel = dyn_cast<SelectInst>(V)) {
2582+ visitValue (Sel->getTrueValue (), FuncVersions);
2583+ visitValue (Sel->getFalseValue (), FuncVersions);
2584+ } else if (auto *Phi = dyn_cast<PHINode>(V))
2585+ for (unsigned I = 0 , E = Phi->getNumIncomingValues (); I != E; ++I)
2586+ visitValue (Phi->getIncomingValue (I), FuncVersions);
2587+ };
2588+
2589+ // Cache containing the mask constructed from a function's target features.
2590+ DenseMap<Function *, uint64_t > FeaturePriorityMap;
2591+
2592+ for (GlobalIFunc &IF : M.ifuncs ()) {
2593+ if (IF.isInterposable ())
2594+ continue ;
2595+
2596+ Function *Resolver = IF.getResolverFunction ();
2597+ if (!Resolver)
2598+ continue ;
2599+
2600+ if (Resolver->isInterposable ())
2601+ continue ;
2602+
2603+ TargetTransformInfo &TTI = GetTTI (*Resolver);
2604+ if (!TTI.hasFMV ())
2605+ return false ;
2606+
2607+ // Discover the callee versions.
2608+ SmallVector<Function *> Callees;
2609+ for (BasicBlock &BB : *Resolver)
2610+ if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator ()))
2611+ visitValue (Ret->getReturnValue (), Callees);
2612+
2613+ if (Callees.empty ())
2614+ continue ;
2615+
2616+ // Cache the feature mask for each callee.
2617+ for (Function *Callee : Callees) {
2618+ auto [It, Inserted] = FeaturePriorityMap.try_emplace (Callee);
2619+ if (Inserted)
2620+ It->second = TTI.getFMVPriority (*Callee);
2621+ }
2622+
2623+ // Sort the callee versions in increasing feature priority order.
2624+ // Every time we find a caller that matches the highest priority
2625+ // callee we pop_back() one from this ordered list.
2626+ llvm::stable_sort (Callees, [&](auto *LHS, auto *RHS) {
2627+ return FeaturePriorityMap[LHS] < FeaturePriorityMap[RHS];
2628+ });
2629+
2630+ // Find the callsites and cache the feature mask for each caller.
2631+ SmallVector<CallBase *> CallSites;
2632+ for (User *U : IF.users ()) {
2633+ if (auto *CB = dyn_cast<CallBase>(U)) {
2634+ if (CB->getCalledOperand () == &IF) {
2635+ Function *Caller = CB->getFunction ();
2636+ auto [It, Inserted] = FeaturePriorityMap.try_emplace (Caller);
2637+ if (Inserted)
2638+ It->second = TTI.getFMVPriority (*Caller);
2639+ CallSites.push_back (CB);
2640+ }
2641+ }
2642+ }
2643+
2644+ // Sort the callsites in decreasing feature priority order.
2645+ llvm::stable_sort (CallSites, [&](auto *LHS, auto *RHS) {
2646+ return FeaturePriorityMap[LHS->getFunction ()] >
2647+ FeaturePriorityMap[RHS->getFunction ()];
2648+ });
2649+
2650+ // Now try to constant fold the resolver for every callsite starting
2651+ // from higher priority callers. This guarantees that as soon as we
2652+ // find a callee whose priority is lower than the expected best match
2653+ // then there is no point in continuing further.
2654+ DenseMap<uint64_t , Function *> foldedResolverCache;
2655+ for (CallBase *CS : CallSites) {
2656+ uint64_t CallerPriority = FeaturePriorityMap[CS->getFunction ()];
2657+ auto [It, Inserted] = foldedResolverCache.try_emplace (CallerPriority);
2658+ Function *&Callee = It->second ;
2659+ if (Inserted)
2660+ Callee = foldResolverForCallSite (CS, CallerPriority, TTI);
2661+ if (Callee) {
2662+ if (!Callees.empty ()) {
2663+ // If the priority of the candidate is greater or equal to
2664+ // the expected best match then it shall be picked. Otherwise
2665+ // there is a higher priority callee without a corresponding
2666+ // caller, in which case abort.
2667+ uint64_t CalleePriority = FeaturePriorityMap[Callee];
2668+ if (CalleePriority == FeaturePriorityMap[Callees.back ()])
2669+ Callees.pop_back ();
2670+ else if (CalleePriority < FeaturePriorityMap[Callees.back ()])
2671+ break ;
2672+ }
2673+ CS->setCalledOperand (Callee);
2674+ Changed = true ;
2675+ } else {
2676+ // Oops, something went wrong. We couldn't fold. Abort.
2677+ break ;
2678+ }
2679+ }
2680+ if (IF.use_empty () ||
2681+ all_of (IF.users (), [](User *U) { return isa<GlobalAlias>(U); }))
2682+ NumIFuncsResolved++;
2683+ }
2684+ return Changed;
2685+ }
2686+
24652687static bool
24662688optimizeGlobalsInModule (Module &M, const DataLayout &DL,
24672689 function_ref<TargetLibraryInfo &(Function &)> GetTLI,
@@ -2525,6 +2747,9 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL,
25252747 // Optimize IFuncs whose callee's are statically known.
25262748 LocalChange |= OptimizeStaticIFuncs (M);
25272749
2750+ // Optimize IFuncs based on the target features of the caller.
2751+ LocalChange |= OptimizeNonTrivialIFuncs (M, GetTTI);
2752+
25282753 // Remove any IFuncs that are now dead.
25292754 LocalChange |= DeleteDeadIFuncs (M, NotDiscardableComdats);
25302755
0 commit comments