Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 131 additions & 81 deletions llvm/lib/Transforms/IPO/GlobalOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2482,20 +2482,21 @@ DeleteDeadIFuncs(Module &M,
// Follows the use-def chain of \p V backwards until it finds a Function,
// in which case it collects it in \p Versions. Returns true on successful
// use-def chain traversal, false otherwise.
static bool collectVersions(TargetTransformInfo &TTI, Value *V,
SmallVectorImpl<Function *> &Versions) {
static bool
collectVersions(Value *V, SmallVectorImpl<Function *> &Versions,
function_ref<TargetTransformInfo &(Function &)> GetTTI) {
if (auto *F = dyn_cast<Function>(V)) {
if (!TTI.isMultiversionedFunction(*F))
if (!GetTTI(*F).isMultiversionedFunction(*F))
return false;
Versions.push_back(F);
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
if (!collectVersions(TTI, Sel->getTrueValue(), Versions))
if (!collectVersions(Sel->getTrueValue(), Versions, GetTTI))
return false;
if (!collectVersions(TTI, Sel->getFalseValue(), Versions))
if (!collectVersions(Sel->getFalseValue(), Versions, GetTTI))
return false;
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
if (!collectVersions(TTI, Phi->getIncomingValue(I), Versions))
if (!collectVersions(Phi->getIncomingValue(I), Versions, GetTTI))
return false;
} else {
// Unknown instruction type. Bail.
Expand Down Expand Up @@ -2525,8 +2526,14 @@ static bool OptimizeNonTrivialIFuncs(
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
bool Changed = false;

// Cache containing the mask constructed from a function's target features.
// Map containing the feature bits for a given function.
DenseMap<Function *, APInt> FeatureMask;
// Map containing all the versions corresponding to an IFunc symbol.
DenseMap<GlobalIFunc *, SmallVector<Function *>> VersionedFuncs;
// Map containing the IFunc symbol a function is version of.
DenseMap<Function *, GlobalIFunc *> VersionOf;
// List of all the interesting IFuncs found in the module.
SmallVector<GlobalIFunc *> IFuncs;

for (GlobalIFunc &IF : M.ifuncs()) {
if (IF.isInterposable())
Expand All @@ -2539,107 +2546,150 @@ static bool OptimizeNonTrivialIFuncs(
if (Resolver->isInterposable())
continue;

TargetTransformInfo &TTI = GetTTI(*Resolver);

// Discover the callee versions.
SmallVector<Function *> Callees;
if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
SmallVector<Function *> Versions;
// Discover the versioned functions.
if (any_of(*Resolver, [&](BasicBlock &BB) {
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
if (!collectVersions(Ret->getReturnValue(), Versions, GetTTI))
return true;
return false;
}))
continue;

if (Callees.empty())
if (Versions.empty())
continue;

LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
<< Resolver->getName() << "\n");

// Cache the feature mask for each callee.
for (Function *Callee : Callees) {
auto [It, Inserted] = FeatureMask.try_emplace(Callee);
for (Function *V : Versions) {
VersionOf.insert({V, &IF});
auto [It, Inserted] = FeatureMask.try_emplace(V);
if (Inserted)
It->second = TTI.getFeatureMask(*Callee);
It->second = GetTTI(*V).getFeatureMask(*V);
}

// Sort the callee versions in decreasing priority order.
sort(Callees, [&](auto *LHS, auto *RHS) {
// Sort function versions in decreasing priority order.
sort(Versions, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
});

// Find the callsites and cache the feature mask for each caller.
SmallVector<Function *> Callers;
IFuncs.push_back(&IF);
VersionedFuncs.try_emplace(&IF, std::move(Versions));
}

for (GlobalIFunc *CalleeIF : IFuncs) {
SmallVector<Function *> NonFMVCallers;
SmallVector<GlobalIFunc *> CallerIFuncs;
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
for (User *U : IF.users()) {

// Find the callsites.
for (User *U : CalleeIF->users()) {
if (auto *CB = dyn_cast<CallBase>(U)) {
if (CB->getCalledOperand() == &IF) {
if (CB->getCalledOperand() == CalleeIF) {
Function *Caller = CB->getFunction();
auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
if (FeatInserted)
FeatIt->second = TTI.getFeatureMask(*Caller);
auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
if (CallInserted)
Callers.push_back(Caller);
CallIt->second.push_back(CB);
GlobalIFunc *CallerIF = nullptr;
TargetTransformInfo &TTI = GetTTI(*Caller);
bool CallerIsFMV = TTI.isMultiversionedFunction(*Caller);
// The caller is a version of a known IFunc.
if (auto It = VersionOf.find(Caller); It != VersionOf.end())
CallerIF = It->second;
else if (!CallerIsFMV && OptimizeNonFMVCallers) {
// The caller is non-FMV.
auto [It, Inserted] = FeatureMask.try_emplace(Caller);
if (Inserted)
It->second = TTI.getFeatureMask(*Caller);
} else
// The caller is none of the above, skip.
continue;
auto [It, Inserted] = CallSites.try_emplace(Caller);
if (Inserted) {
if (CallerIsFMV)
CallerIFuncs.push_back(CallerIF);
else
NonFMVCallers.push_back(Caller);
}
It->second.push_back(CB);
}
}
}

// Sort the caller versions in decreasing priority order.
sort(Callers, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
});

auto implies = [](APInt A, APInt B) { return B.isSubsetOf(A); };

// Index to the highest priority candidate.
unsigned I = 0;
// Now try to redirect calls starting from higher priority callers.
for (Function *Caller : Callers) {
assert(I < Callees.size() && "Found callers of equal priority");

Function *Callee = Callees[I];
APInt CallerBits = FeatureMask[Caller];
APInt CalleeBits = FeatureMask[Callee];
LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
<< CalleeIF->getResolverFunction()->getName() << "\n");

auto redirectCalls = [&](SmallVectorImpl<Function *> &Callers,
SmallVectorImpl<Function *> &Callees) {
// Index to the current callee candidate.
unsigned I = 0;
// Feature bits from callers of previous iterations.
SmallVector<APInt> KnownBits;

// Try to redirect calls starting from higher priority callers.
for (Function *Caller : Callers) {
if (I == Callees.size())
break;

// In the case of FMV callers, we know that all higher priority callers
// than the current one did not get selected at runtime, which helps
// reason about the callees (if they have versions that mandate presence
// of the features which we already know are unavailable on this target).
if (TTI.isMultiversionedFunction(*Caller)) {
// If the feature set of the caller implies the feature set of the
// highest priority candidate then it shall be picked. In case of
// identical sets advance the candidate index one position.
if (CallerBits == CalleeBits)
++I;
else if (!implies(CallerBits, CalleeBits)) {
// Keep advancing the candidate index as long as the caller's
bool CallerIsFMV = GetTTI(*Caller).isMultiversionedFunction(*Caller);
// We can't reason much about non-FMV callers. Just pick the highest
// priority callee if it matches, otherwise bail.
if (!CallerIsFMV)
assert(I == 0 && "Should only select the highest priority candidate");
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `if (!CallerIsFMV)` whose only body is this `assert` may require curly braces for builds without assertions.


APInt CallerBits = FeatureMask[Caller];
APInt CalleeBits = FeatureMask[Callees[I]];
// In the case of FMV callers, we know that all higher priority callers
// than the current one did not get selected at runtime, which helps
// reason about the callees (if they have versions that mandate presence
// of the features which we already know are unavailable on this
// target, then we can skip over those versions/candidates).
if (CallerIsFMV) {
// Keep advancing the candidate index as long as the unavailable
// features are a subset of the current candidate's.
while (implies(CalleeBits, CallerBits)) {
if (++I == Callees.size())
break;
CalleeBits = FeatureMask[Callees[I]];
unsigned J = 0;
while (J < KnownBits.size()) {
// Discard feature bits that are known to be available
// in the current iteration.
APInt Version = KnownBits[J] & ~CallerBits;
if (Version.isSubsetOf(CalleeBits)) {
if (++I == Callees.size())
break;
CalleeBits = FeatureMask[Callees[I]];
// Start over.
J = 0;
} else
++J;
}
KnownBits.push_back(CallerBits);
}
Function *Callee = Callees[I];
// If the feature set of the caller implies the feature set of the
// highest priority candidate then it shall be picked.
if (CalleeBits.isSubsetOf(CallerBits)) {
// If there are no records of call sites for this particular function
// version, then it is not actually a caller, in which case skip.
if (auto It = CallSites.find(Caller); It != CallSites.end()) {
for (CallBase *CS : It->second) {
LLVM_DEBUG(dbgs() << "Redirecting call " << Caller->getName()
<< " -> " << Callee->getName() << "\n");
CS->setCalledOperand(Callee);
}
Changed = true;
}
continue;
}
} else {
// We can't reason much about non-FMV callers. Just pick the highest
// priority callee if it matches, otherwise bail.
if (!OptimizeNonFMVCallers || I > 0 || !implies(CallerBits, CalleeBits))
continue;
}
auto &Calls = CallSites[Caller];
for (CallBase *CS : Calls) {
LLVM_DEBUG(dbgs() << "Redirecting call " << Caller->getName() << " -> "
<< Callee->getName() << "\n");
CS->setCalledOperand(Callee);
}
Changed = true;
};

auto &Callees = VersionedFuncs[CalleeIF];

// Optimize non-FMV calls.
if (OptimizeNonFMVCallers)
redirectCalls(NonFMVCallers, Callees);

// Optimize FMV calls.
for (GlobalIFunc *CallerIF : CallerIFuncs) {
auto &Callers = VersionedFuncs[CallerIF];
redirectCalls(Callers, Callees);
}
if (IF.use_empty() ||
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))

if (CalleeIF->use_empty() ||
all_of(CalleeIF->users(), [](User *U) { return isa<GlobalAlias>(U); }))
NumIFuncsResolved++;
}
return Changed;
Expand Down
Loading