Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 115 additions & 75 deletions llvm/lib/Transforms/IPO/GlobalOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2482,20 +2482,21 @@ DeleteDeadIFuncs(Module &M,
// Follows the use-def chain of \p V backwards until it finds a Function,
// in which case it collects in \p Versions. Return true on successful
// use-def chain traversal, false otherwise.
static bool collectVersions(TargetTransformInfo &TTI, Value *V,
SmallVectorImpl<Function *> &Versions) {
static bool
collectVersions(Value *V, SmallVectorImpl<Function *> &Versions,
function_ref<TargetTransformInfo &(Function &)> GetTTI) {
if (auto *F = dyn_cast<Function>(V)) {
if (!TTI.isMultiversionedFunction(*F))
if (!GetTTI(*F).isMultiversionedFunction(*F))
return false;
Versions.push_back(F);
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
if (!collectVersions(TTI, Sel->getTrueValue(), Versions))
if (!collectVersions(Sel->getTrueValue(), Versions, GetTTI))
return false;
if (!collectVersions(TTI, Sel->getFalseValue(), Versions))
if (!collectVersions(Sel->getFalseValue(), Versions, GetTTI))
return false;
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
if (!collectVersions(TTI, Phi->getIncomingValue(I), Versions))
if (!collectVersions(Phi->getIncomingValue(I), Versions, GetTTI))
return false;
} else {
// Unknown instruction type. Bail.
Expand Down Expand Up @@ -2525,8 +2526,14 @@ static bool OptimizeNonTrivialIFuncs(
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
bool Changed = false;

// Cache containing the mask constructed from a function's target features.
// Map containing the feature bits for a given function.
DenseMap<Function *, APInt> FeatureMask;
// Map containing all the versions corresponding to an IFunc symbol.
DenseMap<GlobalIFunc *, SmallVector<Function *>> VersionedFuncs;
// Map containing the IFunc symbol a function is version of.
DenseMap<Function *, GlobalIFunc *> VersionOf;
// List of all the interesting IFuncs found in the module.
SmallVector<GlobalIFunc *> IFuncs;

for (GlobalIFunc &IF : M.ifuncs()) {
if (IF.isInterposable())
Expand All @@ -2539,107 +2546,140 @@ static bool OptimizeNonTrivialIFuncs(
if (Resolver->isInterposable())
continue;

TargetTransformInfo &TTI = GetTTI(*Resolver);

// Discover the callee versions.
SmallVector<Function *> Callees;
if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
SmallVector<Function *> Versions;
// Discover the versioned functions.
if (any_of(*Resolver, [&](BasicBlock &BB) {
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
if (!collectVersions(Ret->getReturnValue(), Versions, GetTTI))
return true;
return false;
}))
continue;

if (Callees.empty())
if (Versions.empty())
continue;

LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
<< Resolver->getName() << "\n");

// Cache the feature mask for each callee.
for (Function *Callee : Callees) {
auto [It, Inserted] = FeatureMask.try_emplace(Callee);
for (Function *V : Versions) {
VersionOf.insert({V, &IF});
auto [It, Inserted] = FeatureMask.try_emplace(V);
if (Inserted)
It->second = TTI.getFeatureMask(*Callee);
It->second = GetTTI(*V).getFeatureMask(*V);
}

// Sort the callee versions in decreasing priority order.
sort(Callees, [&](auto *LHS, auto *RHS) {
// Sort function versions in decreasing priority order.
sort(Versions, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
});

// Find the callsites and cache the feature mask for each caller.
SmallVector<Function *> Callers;
IFuncs.push_back(&IF);
VersionedFuncs.try_emplace(&IF, std::move(Versions));
}

for (GlobalIFunc *CalleeIF : IFuncs) {
SmallVector<Function *> NonFMVCallers;
SmallVector<GlobalIFunc *> CallerIFuncs;
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
for (User *U : IF.users()) {

// Find the callsites.
for (User *U : CalleeIF->users()) {
if (auto *CB = dyn_cast<CallBase>(U)) {
if (CB->getCalledOperand() == &IF) {
if (CB->getCalledOperand() == CalleeIF) {
Function *Caller = CB->getFunction();
auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
if (FeatInserted)
FeatIt->second = TTI.getFeatureMask(*Caller);
auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
if (CallInserted)
Callers.push_back(Caller);
CallIt->second.push_back(CB);
GlobalIFunc *CallerIF = nullptr;
TargetTransformInfo &TTI = GetTTI(*Caller);
bool CallerIsFMV = TTI.isMultiversionedFunction(*Caller);
// The caller is a version of a known IFunc.
if (auto It = VersionOf.find(Caller); It != VersionOf.end())
CallerIF = It->second;
else if (!CallerIsFMV && OptimizeNonFMVCallers) {
// The caller is non-FMV.
auto [It, Inserted] = FeatureMask.try_emplace(Caller);
if (Inserted)
It->second = TTI.getFeatureMask(*Caller);
} else
// The caller is none of the above, skip.
continue;
auto [It, Inserted] = CallSites.try_emplace(Caller);
if (Inserted) {
if (CallerIsFMV)
CallerIFuncs.push_back(CallerIF);
else
NonFMVCallers.push_back(Caller);
}
It->second.push_back(CB);
}
}
}

// Sort the caller versions in decreasing priority order.
sort(Callers, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
});

auto implies = [](APInt A, APInt B) { return B.isSubsetOf(A); };
LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
<< CalleeIF->getResolverFunction()->getName() << "\n");

// Index to the highest priority candidate.
unsigned I = 0;
// Now try to redirect calls starting from higher priority callers.
for (Function *Caller : Callers) {
assert(I < Callees.size() && "Found callers of equal priority");
auto redirectCalls = [&](SmallVectorImpl<Function *> &Callers,
SmallVectorImpl<Function *> &Callees) {
// Index to the current callee candidate.
unsigned I = 0;

Function *Callee = Callees[I];
APInt CallerBits = FeatureMask[Caller];
APInt CalleeBits = FeatureMask[Callee];
// Try to redirect calls starting from higher priority callers.
for (Function *Caller : Callers) {
if (I == Callees.size())
break;

// In the case of FMV callers, we know that all higher priority callers
// than the current one did not get selected at runtime, which helps
// reason about the callees (if they have versions that mandate presence
// of the features which we already know are unavailable on this target).
if (TTI.isMultiversionedFunction(*Caller)) {
bool CallerIsFMV = GetTTI(*Caller).isMultiversionedFunction(*Caller);
// In the case of FMV callers, we know that all higher priority callers
// than the current one did not get selected at runtime, which helps
// reason about the callees (if they have versions that mandate presence
// of the features which we already know are unavailable on this
// target).
if (!CallerIsFMV)
// We can't reason much about non-FMV callers. Just pick the highest
// priority callee if it matches, otherwise bail.
assert(I == 0 && "Should only select the highest priority candidate");

Function *Callee = Callees[I];
APInt CallerBits = FeatureMask[Caller];
APInt CalleeBits = FeatureMask[Callee];
// If the feature set of the caller implies the feature set of the
// highest priority candidate then it shall be picked. In case of
// identical sets advance the candidate index one position.
if (CallerBits == CalleeBits)
++I;
else if (!implies(CallerBits, CalleeBits)) {
// Keep advancing the candidate index as long as the caller's
// features are a subset of the current candidate's.
while (implies(CalleeBits, CallerBits)) {
// highest priority candidate then it shall be picked.
if (CalleeBits.isSubsetOf(CallerBits)) {
// If there are no records of call sites for this particular function
// version, then it is not actually a caller, in which case skip.
if (auto It = CallSites.find(Caller); It != CallSites.end()) {
for (CallBase *CS : It->second) {
LLVM_DEBUG(dbgs() << "Redirecting call " << Caller->getName()
<< " -> " << Callee->getName() << "\n");
CS->setCalledOperand(Callee);
}
Changed = true;
}
}
// Keep advancing the candidate index as long as the caller's
// features are a subset of the current candidate's.
if (CallerIsFMV) {
while (CallerBits.isSubsetOf(CalleeBits)) {
if (++I == Callees.size())
break;
CalleeBits = FeatureMask[Callees[I]];
}
continue;
}
} else {
// We can't reason much about non-FMV callers. Just pick the highest
// priority callee if it matches, otherwise bail.
if (!OptimizeNonFMVCallers || I > 0 || !implies(CallerBits, CalleeBits))
continue;
}
auto &Calls = CallSites[Caller];
for (CallBase *CS : Calls) {
LLVM_DEBUG(dbgs() << "Redirecting call " << Caller->getName() << " -> "
<< Callee->getName() << "\n");
CS->setCalledOperand(Callee);
};

auto &Callees = VersionedFuncs[CalleeIF];

// Optimize non-FMV calls.
if (!NonFMVCallers.empty() && OptimizeNonFMVCallers)
redirectCalls(NonFMVCallers, Callees);

// Optimize FMV calls.
if (!CallerIFuncs.empty()) {
for (GlobalIFunc *CallerIF : CallerIFuncs) {
auto &Callers = VersionedFuncs[CallerIF];
redirectCalls(Callers, Callees);
}
Changed = true;
}
if (IF.use_empty() ||
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))

if (CalleeIF->use_empty() ||
all_of(CalleeIF->users(), [](User *U) { return isa<GlobalAlias>(U); }))
NumIFuncsResolved++;
}
return Changed;
Expand Down
Loading