@@ -730,50 +730,54 @@ bool DataAggregator::doInterBranch(BinaryFunction *FromFunc,
730730 return true ;
731731}
732732
733+ bool DataAggregator::checkReturn (uint64_t Addr) {
734+ auto isReturn = [&](auto MI) { return MI && BC->MIB ->isReturn (*MI); };
735+ if (llvm::is_contained (Returns, Addr))
736+ return true ;
737+
738+ BinaryFunction *Func = getBinaryFunctionContainingAddress (Addr);
739+ if (!Func)
740+ return false ;
741+
742+ const uint64_t Offset = Addr - Func->getAddress ();
743+ if (Func->hasInstructions ()
744+ ? isReturn (Func->getInstructionAtOffset (Offset))
745+ : isReturn (Func->disassembleInstructionAtOffset (Offset))) {
746+ Returns.emplace (Addr);
747+ return true ;
748+ }
749+ return false ;
750+ }
751+
733752bool DataAggregator::doBranch (uint64_t From, uint64_t To, uint64_t Count,
734753 uint64_t Mispreds) {
735- // Returns whether \p Offset in \p Func contains a return instruction.
736- auto checkReturn = [&](const BinaryFunction &Func, const uint64_t Offset) {
737- auto isReturn = [&](auto MI) { return MI && BC->MIB ->isReturn (*MI); };
738- return Func.hasInstructions ()
739- ? isReturn (Func.getInstructionAtOffset (Offset))
740- : isReturn (Func.disassembleInstructionAtOffset (Offset));
741- };
742-
743754 // Mutates \p Addr to an offset into the containing function, performing BAT
744755 // offset translation and parent lookup.
745756 //
746- // Returns the containing function (or BAT parent) and whether the address
747- // corresponds to a return (if \p IsFrom) or a call continuation (otherwise).
757+ // Returns the containing function (or BAT parent).
748758 auto handleAddress = [&](uint64_t &Addr, bool IsFrom) {
749759 BinaryFunction *Func = getBinaryFunctionContainingAddress (Addr);
750760 if (!Func) {
751761 Addr = 0 ;
752- return std::pair{ Func, false } ;
762+ return Func;
753763 }
754764
755765 Addr -= Func->getAddress ();
756766
757- bool IsRet = IsFrom && checkReturn (*Func, Addr);
758-
759767 if (BAT)
760768 Addr = BAT->translate (Func->getAddress (), Addr, IsFrom);
761769
762770 if (BinaryFunction *ParentFunc = getBATParentFunction (*Func))
763- Func = ParentFunc;
771+ return ParentFunc;
764772
765- return std::pair{ Func, IsRet} ;
773+ return Func;
766774 };
767775
768- auto [ FromFunc, IsReturn] = handleAddress (From, /* IsFrom*/ true );
769- auto [ ToFunc, _] = handleAddress (To, /* IsFrom*/ false );
776+ BinaryFunction * FromFunc = handleAddress (From, /* IsFrom*/ true );
777+ BinaryFunction * ToFunc = handleAddress (To, /* IsFrom*/ false );
770778 if (!FromFunc && !ToFunc)
771779 return false ;
772780
773- // Ignore returns.
774- if (IsReturn)
775- return true ;
776-
777781 // Treat recursive control transfers as inter-branches.
778782 if (FromFunc == ToFunc && To != 0 ) {
779783 recordBranch (*FromFunc, From, To, Count, Mispreds);
@@ -783,7 +787,8 @@ bool DataAggregator::doBranch(uint64_t From, uint64_t To, uint64_t Count,
783787 return doInterBranch (FromFunc, ToFunc, From, To, Count, Mispreds);
784788}
785789
786- bool DataAggregator::doTrace (const Trace &Trace, uint64_t Count) {
790+ bool DataAggregator::doTrace (const Trace &Trace, uint64_t Count,
791+ bool IsReturn) {
787792 const uint64_t From = Trace.From , To = Trace.To ;
788793 BinaryFunction *FromFunc = getBinaryFunctionContainingAddress (From);
789794 BinaryFunction *ToFunc = getBinaryFunctionContainingAddress (To);
@@ -808,8 +813,8 @@ bool DataAggregator::doTrace(const Trace &Trace, uint64_t Count) {
808813 const uint64_t FuncAddress = FromFunc->getAddress ();
809814 std::optional<BoltAddressTranslation::FallthroughListTy> FTs =
810815 BAT && BAT->isBATFunction (FuncAddress)
811- ? BAT->getFallthroughsInTrace (FuncAddress, From, To)
812- : getFallthroughsInTrace (*FromFunc, Trace, Count);
816+ ? BAT->getFallthroughsInTrace (FuncAddress, From - IsReturn , To)
817+ : getFallthroughsInTrace (*FromFunc, Trace, Count, IsReturn );
813818 if (!FTs) {
814819 LLVM_DEBUG (dbgs () << " Invalid trace " << Trace << ' \n ' );
815820 NumInvalidTraces += Count;
@@ -831,7 +836,7 @@ bool DataAggregator::doTrace(const Trace &Trace, uint64_t Count) {
831836
832837std::optional<SmallVector<std::pair<uint64_t , uint64_t >, 16 >>
833838DataAggregator::getFallthroughsInTrace (BinaryFunction &BF, const Trace &Trace,
834- uint64_t Count) const {
839+ uint64_t Count, bool IsReturn ) const {
835840 SmallVector<std::pair<uint64_t , uint64_t >, 16 > Branches;
836841
837842 BinaryContext &BC = BF.getBinaryContext ();
@@ -865,9 +870,13 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace,
865870
866871 // Adjust FromBB if the first LBR is a return from the last instruction in
867872 // the previous block (that instruction should be a call).
868- if (Trace.Branch != Trace::FT_ONLY && !BF.containsAddress (Trace.Branch ) &&
869- From == FromBB->getOffset () && !FromBB->isEntryPoint () &&
870- !FromBB->isLandingPad ()) {
873+ if (IsReturn) {
874+ if (From)
875+ FromBB = BF.getBasicBlockContainingOffset (From - 1 );
876+ else
877+ LLVM_DEBUG (dbgs () << " return to the function start: " << Trace << ' \n ' );
878+ } else if (Trace.Branch == Trace::EXTERNAL && From == FromBB->getOffset () &&
879+ !FromBB->isEntryPoint () && !FromBB->isLandingPad ()) {
871880 const BinaryBasicBlock *PrevBB =
872881 BF.getLayout ().getBlock (FromBB->getIndex () - 1 );
873882 if (PrevBB->getSuccessor (FromBB->getLabel ())) {
@@ -1557,11 +1566,13 @@ void DataAggregator::processBranchEvents() {
15571566 TimerGroupName, TimerGroupDesc, opts::TimeAggregator);
15581567
15591568 for (const auto &[Trace, Info] : Traces) {
1560- if (Trace.Branch != Trace::FT_ONLY &&
1569+ bool IsReturn = checkReturn (Trace.Branch );
1570+ // Ignore returns.
1571+ if (!IsReturn && Trace.Branch != Trace::FT_ONLY &&
15611572 Trace.Branch != Trace::FT_EXTERNAL_ORIGIN)
15621573 doBranch (Trace.Branch , Trace.From , Info.TakenCount , Info.MispredCount );
15631574 if (Trace.To != Trace::BR_ONLY)
1564- doTrace (Trace, Info.TakenCount );
1575+ doTrace (Trace, Info.TakenCount , IsReturn );
15651576 }
15661577 printBranchSamplesDiagnostics ();
15671578}
0 commit comments