diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp index 61a8f4a448bbd..988e912b2de83 100644 --- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp +++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp @@ -821,19 +821,31 @@ struct IndexCall : public PointerUnion { IndexCall *operator->() { return this; } - PointerUnion getBase() const { return *this; } - void print(raw_ostream &OS) const { - if (auto *AI = llvm::dyn_cast_if_present(getBase())) { + PointerUnion Base = *this; + if (auto *AI = llvm::dyn_cast_if_present(Base)) { OS << *AI; } else { - auto *CI = llvm::dyn_cast_if_present(getBase()); + auto *CI = llvm::dyn_cast_if_present(Base); assert(CI); OS << *CI; } } }; +} // namespace + +namespace llvm { +template <> struct simplify_type { + using SimpleType = PointerUnion; + static SimpleType getSimplifiedValue(IndexCall &Val) { return Val; } +}; +template <> struct simplify_type { + using SimpleType = const PointerUnion; + static SimpleType getSimplifiedValue(const IndexCall &Val) { return Val; } +}; +} // namespace llvm +namespace { /// CRTP derived class for graphs built from summary index (ThinLTO). class IndexCallsiteContextGraph : public CallsiteContextGraph(Call.getBase())); + assert(isa(Call)); CallStack::const_iterator> - CallsiteContext(dyn_cast_if_present(Call.getBase())); + CallsiteContext(dyn_cast_if_present(Call)); // Need to convert index into stack id. return Index.getStackIdAtIndex(CallsiteContext.back()); } @@ -1911,10 +1923,10 @@ std::string IndexCallsiteContextGraph::getLabel(const FunctionSummary *Func, unsigned CloneNo) const { auto VI = FSToVIMap.find(Func); assert(VI != FSToVIMap.end()); - if (isa(Call.getBase())) + if (isa(Call)) return (VI->second.name() + " -> alloc").str(); else { - auto *Callsite = dyn_cast_if_present(Call.getBase()); + auto *Callsite = dyn_cast_if_present(Call); return (VI->second.name() + " -> " + getMemProfFuncName(Callsite->Callee.name(), Callsite->Clones[CloneNo])) @@ -1933,9 +1945,9 @@ ModuleCallsiteContextGraph::getStackIdsWithContextNodesForCall( std::vector IndexCallsiteContextGraph::getStackIdsWithContextNodesForCall(IndexCall &Call) { - assert(isa(Call.getBase())); + assert(isa(Call)); CallStack::const_iterator> - CallsiteContext(dyn_cast_if_present(Call.getBase())); + CallsiteContext(dyn_cast_if_present(Call)); return getStackIdsWithContextNodes::const_iterator>( CallsiteContext); @@ -2696,8 +2708,7 @@ bool IndexCallsiteContextGraph::findProfiledCalleeThroughTailCalls( const FunctionSummary * IndexCallsiteContextGraph::getCalleeFunc(IndexCall &Call) { - ValueInfo Callee = - dyn_cast_if_present(Call.getBase())->Callee; + ValueInfo Callee = dyn_cast_if_present(Call)->Callee; if (Callee.getSummaryList().empty()) return nullptr; return dyn_cast(Callee.getSummaryList()[0]->getBaseObject()); @@ -2707,8 +2718,7 @@ bool IndexCallsiteContextGraph::calleeMatchesFunc( IndexCall &Call, const FunctionSummary *Func, const FunctionSummary *CallerFunc, std::vector> &FoundCalleeChain) { - ValueInfo Callee = - dyn_cast_if_present(Call.getBase())->Callee; + ValueInfo Callee = dyn_cast_if_present(Call)->Callee; // If there is no summary list then this is a call to an externally defined // symbol. AliasSummary *Alias = @@ -2751,10 +2761,8 @@ bool IndexCallsiteContextGraph::calleeMatchesFunc( } bool IndexCallsiteContextGraph::sameCallee(IndexCall &Call1, IndexCall &Call2) { - ValueInfo Callee1 = - dyn_cast_if_present(Call1.getBase())->Callee; - ValueInfo Callee2 = - dyn_cast_if_present(Call2.getBase())->Callee; + ValueInfo Callee1 = dyn_cast_if_present(Call1)->Callee; + ValueInfo Callee2 = dyn_cast_if_present(Call2)->Callee; return Callee1 == Callee2; } @@ -3610,7 +3618,7 @@ IndexCallsiteContextGraph::cloneFunctionForCallsite( // Confirm this matches the CloneNo provided by the caller, which is based on // the number of function clones we have. assert(CloneNo == - (isa(Call.call().getBase()) + (isa(Call.call()) ? Call.call().dyn_cast()->Versions.size() : Call.call().dyn_cast()->Clones.size())); // Walk all the instructions in this function. Create a new version for