1414#include " bolt/Passes/PAuthGadgetScanner.h"
1515#include " bolt/Core/ParallelUtilities.h"
1616#include " bolt/Passes/DataflowAnalysis.h"
17+ #include " bolt/Utils/CommandLineOpts.h"
1718#include " llvm/ADT/STLExtras.h"
1819#include " llvm/ADT/SmallSet.h"
1920#include " llvm/MC/MCInst.h"
@@ -26,6 +27,11 @@ namespace llvm {
2627namespace bolt {
2728namespace PAuthGadgetScanner {
2829
30+ static cl::opt<bool > AuthTrapsOnFailure (
31+ " auth-traps-on-failure" ,
32+ cl::desc (" Assume authentication instructions always trap on failure" ),
33+ cl::cat(opts::BinaryAnalysisCategory));
34+
2935[[maybe_unused]] static void traceInst (const BinaryContext &BC, StringRef Label,
3036 const MCInst &MI) {
3137 dbgs () << " " << Label << " : " ;
@@ -82,8 +88,8 @@ class TrackedRegisters {
// Constructs the tracker from the list of registers to track.
// Registers keeps the list as given; RegToIndexMapping is created with
// getMappingSize(RegsToTrack) entries, each initialised to NoIndex, and is
// then filled so that RegToIndexMapping[Reg] is the position of Reg inside
// RegsToTrack. Registers that are not listed keep the NoIndex value.
8288 TrackedRegisters (ArrayRef<MCPhysReg> RegsToTrack)
8389 : Registers(RegsToTrack),
8490 RegToIndexMapping (getMappingSize(RegsToTrack), NoIndex) {
// Index-based form of the fill loop (lines marked '-').
85- for (unsigned I = 0 ; I < RegsToTrack. size (); ++I )
86- RegToIndexMapping[RegsToTrack[I]] = I ;
// Same fill expressed with llvm::enumerate (lines marked '+'): each element
// is visited together with its position, so no manual subscripting is needed.
91+ for (auto [MappedIndex, Reg] : llvm::enumerate (RegsToTrack) )
92+ RegToIndexMapping[Reg] = MappedIndex ;
8793 }
8894
8995 ArrayRef<MCPhysReg> getRegisters () const { return Registers; }
@@ -197,9 +203,9 @@ struct SrcState {
197203
198204 SafeToDerefRegs &= StateIn.SafeToDerefRegs ;
199205 TrustedRegs &= StateIn.TrustedRegs ;
200- for (unsigned I = 0 ; I < LastInstWritingReg. size (); ++I)
201- for ( const MCInst *J : StateIn.LastInstWritingReg [I] )
202- LastInstWritingReg[I]. insert (J );
206+ for (auto [ThisSet, OtherSet] :
207+ llvm::zip_equal (LastInstWritingReg, StateIn.LastInstWritingReg ) )
208+ ThisSet. insert_range (OtherSet );
203209 return *this ;
204210 }
205211
@@ -218,11 +224,9 @@ struct SrcState {
// Writes a compact, single-line dump of Insts to OS: an "Insts:" label
// followed, for every set in the array, by its position in brackets and the
// set's elements in parentheses.
//   OS    - stream that receives the text.
//   Insts - array of instruction-pointer sets, printed in array order.
218224static void printInstsShort (raw_ostream &OS,
219225 ArrayRef<SetOfRelatedInsts> Insts) {
220226 OS << " Insts: " ;
// Index-based form of the outer loop (lines marked '-').
221- for (unsigned I = 0 ; I < Insts.size (); ++I) {
222- auto &Set = Insts[I];
// Same iteration via llvm::enumerate (line marked '+'): I is the position,
// PtrSet the set stored at that position.
227+ for (auto [I, PtrSet] : llvm::enumerate (Insts)) {
223228 OS << " [" << I << " ](" ;
// '-' form: streams every MCInst pointer followed by a separator, so the
// separator also trails the last element.
224- for (const MCInst *MCInstP : Set)
225- OS << MCInstP << " " ;
// '+' form: interleave() is given the set, the stream and a separator string.
// NOTE(review): presumably it emits the separator only between elements, so
// the trailing separator of the '-' form would disappear; confirm.
229+ interleave (PtrSet, OS, " " );
226230 OS << " )" ;
227231 }
228232}
@@ -364,6 +368,34 @@ class SrcSafetyAnalysis {
364368 return Clobbered;
365369 }
366370
371+ std::optional<MCPhysReg> getRegMadeTrustedByChecking (const MCInst &Inst,
372+ SrcState Cur) const {
373+ // This function cannot return multiple registers. This is never the case
374+ // on AArch64.
375+ std::optional<MCPhysReg> RegCheckedByInst =
376+ BC.MIB ->getAuthCheckedReg (Inst, /* MayOverwrite=*/ false );
377+ if (RegCheckedByInst && Cur.SafeToDerefRegs [*RegCheckedByInst])
378+ return *RegCheckedByInst;
379+
380+ auto It = CheckerSequenceInfo.find (&Inst);
381+ if (It == CheckerSequenceInfo.end ())
382+ return std::nullopt ;
383+
384+ MCPhysReg RegCheckedBySequence = It->second .first ;
385+ const MCInst *FirstCheckerInst = It->second .second ;
386+
387+ // FirstCheckerInst should belong to the same basic block (see the
388+ // assertion in DataflowSrcSafetyAnalysis::run()), meaning it was
389+ // deterministically processed a few steps before this instruction.
390+ const SrcState &StateBeforeChecker = getStateBefore (*FirstCheckerInst);
391+
392+ // The sequence checks the register, but it should be authenticated before.
393+ if (!StateBeforeChecker.SafeToDerefRegs [RegCheckedBySequence])
394+ return std::nullopt ;
395+
396+ return RegCheckedBySequence;
397+ }
398+
367399 // Returns all registers that can be treated as if they are written by an
368400 // authentication instruction.
369401 SmallVector<MCPhysReg> getRegsMadeSafeToDeref (const MCInst &Point,
@@ -382,22 +414,43 @@ class SrcSafetyAnalysis {
382414 // ... an address can be updated in a safe manner, producing the result
383415 // which is as trusted as the input address.
384416 if (auto DstAndSrc = BC.MIB ->analyzeAddressArithmeticsForPtrAuth (Point)) {
385- if (Cur.SafeToDerefRegs [DstAndSrc->second ])
386- Regs.push_back (DstAndSrc->first );
417+ auto [DstReg, SrcReg] = *DstAndSrc;
418+ if (Cur.SafeToDerefRegs [SrcReg])
419+ Regs.push_back (DstReg);
387420 }
388421
422+ // Make sure explicit checker sequence keeps register safe-to-dereference
423+ // when the register would be clobbered according to the regular rules:
424+ //
425+ // ; LR is safe to dereference here
426+ // mov x16, x30 ; start of the sequence, LR is s-t-d right before
427+ // xpaclri ; clobbers LR, LR is not safe anymore
428+ // cmp x30, x16
429+ // b.eq 1f ; end of the sequence: LR is marked as trusted
430+ // brk 0x1234
431+ // 1:
432+ // ; at this point LR would be marked as trusted,
433+ // ; but not safe-to-dereference
434+ //
435+ // or even just
436+ //
437+ // ; X1 is safe to dereference here
438+ // ldr x0, [x1, #8]!
439+ // ; X1 is trusted here, but it was clobbered due to address write-back
440+ if (auto CheckedReg = getRegMadeTrustedByChecking (Point, Cur))
441+ Regs.push_back (*CheckedReg);
442+
389443 return Regs;
390444 }
391445
392446 // Returns all registers made trusted by this instruction.
393447 SmallVector<MCPhysReg> getRegsMadeTrusted (const MCInst &Point,
394448 const SrcState &Cur) const {
449+ assert (!AuthTrapsOnFailure && " Use getRegsMadeSafeToDeref instead" );
395450 SmallVector<MCPhysReg> Regs;
396451
397452 // An authenticated pointer can be checked, or
398- std::optional<MCPhysReg> CheckedReg =
399- BC.MIB ->getAuthCheckedReg (Point, /* MayOverwrite=*/ false );
400- if (CheckedReg && Cur.SafeToDerefRegs [*CheckedReg])
453+ if (auto CheckedReg = getRegMadeTrustedByChecking (Point, Cur))
401454 Regs.push_back (*CheckedReg);
402455
403456 // ... a pointer can be authenticated by an instruction that always checks
@@ -408,28 +461,16 @@ class SrcSafetyAnalysis {
408461 if (AutReg && IsChecked)
409462 Regs.push_back (*AutReg);
410463
411- if (CheckerSequenceInfo.contains (&Point)) {
412- MCPhysReg CheckedReg;
413- const MCInst *FirstCheckerInst;
414- std::tie (CheckedReg, FirstCheckerInst) = CheckerSequenceInfo.at (&Point);
415-
416- // FirstCheckerInst should belong to the same basic block (see the
417- // assertion in DataflowSrcSafetyAnalysis::run()), meaning it was
418- // deterministically processed a few steps before this instruction.
419- const SrcState &StateBeforeChecker = getStateBefore (*FirstCheckerInst);
420- if (StateBeforeChecker.SafeToDerefRegs [CheckedReg])
421- Regs.push_back (CheckedReg);
422- }
423-
424464 // ... a safe address can be materialized, or
425465 if (auto NewAddrReg = BC.MIB ->getMaterializedAddressRegForPtrAuth (Point))
426466 Regs.push_back (*NewAddrReg);
427467
428468 // ... an address can be updated in a safe manner, producing the result
429469 // which is as trusted as the input address.
430470 if (auto DstAndSrc = BC.MIB ->analyzeAddressArithmeticsForPtrAuth (Point)) {
431- if (Cur.TrustedRegs [DstAndSrc->second ])
432- Regs.push_back (DstAndSrc->first );
471+ auto [DstReg, SrcReg] = *DstAndSrc;
472+ if (Cur.TrustedRegs [SrcReg])
473+ Regs.push_back (DstReg);
433474 }
434475
435476 return Regs;
@@ -463,28 +504,11 @@ class SrcSafetyAnalysis {
463504 BitVector Clobbered = getClobberedRegs (Point);
464505 SmallVector<MCPhysReg> NewSafeToDerefRegs =
465506 getRegsMadeSafeToDeref (Point, Cur);
466- SmallVector<MCPhysReg> NewTrustedRegs = getRegsMadeTrusted (Point, Cur);
467-
468- // Ideally, being trusted is a strictly stronger property than being
469- // safe-to-dereference. To simplify the computation of Next state, enforce
470- // this for NewSafeToDerefRegs and NewTrustedRegs. Additionally, this
471- // fixes the properly for "cumulative" register states in tricky cases
472- // like the following:
473- //
474- // ; LR is safe to dereference here
475- // mov x16, x30 ; start of the sequence, LR is s-t-d right before
476- // xpaclri ; clobbers LR, LR is not safe anymore
477- // cmp x30, x16
478- // b.eq 1f ; end of the sequence: LR is marked as trusted
479- // brk 0x1234
480- // 1:
481- // ; at this point LR would be marked as trusted,
482- // ; but not safe-to-dereference
483- //
484- for (auto TrustedReg : NewTrustedRegs) {
485- if (!is_contained (NewSafeToDerefRegs, TrustedReg))
486- NewSafeToDerefRegs.push_back (TrustedReg);
487- }
507+ // If authentication instructions trap on failure, safe-to-dereference
508+ // registers are always trusted.
509+ SmallVector<MCPhysReg> NewTrustedRegs =
510+ AuthTrapsOnFailure ? NewSafeToDerefRegs
511+ : getRegsMadeTrusted (Point, Cur);
488512
489513 // Then, compute the state after this instruction is executed.
490514 SrcState Next = Cur;
@@ -521,6 +545,11 @@ class SrcSafetyAnalysis {
521545 dbgs () << " )\n " ;
522546 });
523547
548+ // Being trusted is a strictly stronger property than being
549+ // safe-to-dereference.
550+ assert (!Next.TrustedRegs .test (Next.SafeToDerefRegs ) &&
551+ " SafeToDerefRegs should contain all TrustedRegs" );
552+
524553 return Next;
525554 }
526555
@@ -836,9 +865,9 @@ struct DstState {
836865 return (*this = StateIn);
837866
838867 CannotEscapeUnchecked &= StateIn.CannotEscapeUnchecked ;
839- for (unsigned I = 0 ; I < FirstInstLeakingReg. size (); ++I)
840- for ( const MCInst *J : StateIn.FirstInstLeakingReg [I] )
841- FirstInstLeakingReg[I]. insert (J );
868+ for (auto [ThisSet, OtherSet] :
869+ llvm::zip_equal (FirstInstLeakingReg, StateIn.FirstInstLeakingReg ) )
870+ ThisSet. insert_range (OtherSet );
842871 return *this ;
843872 }
844873
@@ -1004,8 +1033,7 @@ class DstSafetyAnalysis {
10041033
10051034 // ... an address can be updated in a safe manner, or
10061035 if (auto DstAndSrc = BC.MIB ->analyzeAddressArithmeticsForPtrAuth (Inst)) {
1007- MCPhysReg DstReg, SrcReg;
1008- std::tie (DstReg, SrcReg) = *DstAndSrc;
1036+ auto [DstReg, SrcReg] = *DstAndSrc;
10091037 // Note that *all* registers containing the derived values must be safe,
10101038 // both source and destination ones. No temporaries are supported for now.
10111039 if (Cur.CannotEscapeUnchecked [SrcReg] &&
@@ -1045,7 +1073,7 @@ class DstSafetyAnalysis {
10451073 // If this instruction terminates the program immediately, no
10461074 // authentication oracles are possible past this point.
10471075 if (BC.MIB ->isTrap (Point)) {
1048- LLVM_DEBUG ({ traceInst (BC, " Trap instruction found" , Point); } );
1076+ LLVM_DEBUG (traceInst (BC, " Trap instruction found" , Point));
10491077 DstState Next (NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
10501078 Next.CannotEscapeUnchecked .set ();
10511079 return Next;
@@ -1130,6 +1158,11 @@ class DataflowDstSafetyAnalysis
11301158 }
11311159
11321160 void run () override {
1161+ // As long as DstSafetyAnalysis is only computed to detect authentication
1162+ // oracles, it is a waste of time to compute it when authentication
1163+ // instructions are known to always trap on failure.
1164+ assert (!AuthTrapsOnFailure &&
1165+ " DstSafetyAnalysis is useless with faulting auth" );
11331166 for (BinaryBasicBlock &BB : Func) {
11341167 if (auto CheckerInfo = BC.MIB ->getAuthCheckedReg (BB)) {
11351168 LLVM_DEBUG ({
@@ -1215,7 +1248,7 @@ class CFGUnawareDstSafetyAnalysis : public DstSafetyAnalysis,
12151248 // starting to analyze Inst.
12161249 if (BC.MIB ->isCall (Inst) || BC.MIB ->isBranch (Inst) ||
12171250 BC.MIB ->isReturn (Inst)) {
1218- LLVM_DEBUG ({ traceInst (BC, " Control flow instruction" , Inst); } );
1251+ LLVM_DEBUG (traceInst (BC, " Control flow instruction" , Inst));
12191252 S = createUnsafeState ();
12201253 }
12211254
@@ -1360,7 +1393,7 @@ shouldReportUnsafeTailCall(const BinaryContext &BC, const BinaryFunction &BF,
13601393 // such libc, ignore tail calls performed by ELF entry function.
13611394 if (BC.StartFunctionAddress &&
13621395 *BC.StartFunctionAddress == Inst.getFunction ()->getAddress ()) {
1363- LLVM_DEBUG ({ dbgs () << " Skipping tail call in ELF entry function.\n " ; } );
1396+ LLVM_DEBUG (dbgs () << " Skipping tail call in ELF entry function.\n " );
13641397 return std::nullopt ;
13651398 }
13661399
@@ -1434,7 +1467,7 @@ shouldReportAuthOracle(const BinaryContext &BC, const MCInstReference &Inst,
14341467 });
14351468
14361469 if (S.empty ()) {
1437- LLVM_DEBUG ({ dbgs () << " DstState is empty!\n " ; } );
1470+ LLVM_DEBUG (dbgs () << " DstState is empty!\n " );
14381471 return make_generic_report (
14391472 Inst, " Warning: no state computed for an authentication instruction "
14401473 " (possibly unreachable)" );
@@ -1461,7 +1494,7 @@ collectRegsToTrack(ArrayRef<PartialReport<MCPhysReg>> Reports) {
14611494void FunctionAnalysisContext::findUnsafeUses (
14621495 SmallVector<PartialReport<MCPhysReg>> &Reports) {
14631496 auto Analysis = SrcSafetyAnalysis::create (BF, AllocatorId, {});
1464- LLVM_DEBUG ({ dbgs () << " Running src register safety analysis...\n " ; } );
1497+ LLVM_DEBUG (dbgs () << " Running src register safety analysis...\n " );
14651498 Analysis->run ();
14661499 LLVM_DEBUG ({
14671500 dbgs () << " After src register safety analysis:\n " ;
@@ -1518,8 +1551,7 @@ void FunctionAnalysisContext::findUnsafeUses(
15181551
15191552 const SrcState &S = Analysis->getStateBefore (Inst);
15201553 if (S.empty ()) {
1521- LLVM_DEBUG (
1522- { traceInst (BC, " Instruction has no state, skipping" , Inst); });
1554+ LLVM_DEBUG (traceInst (BC, " Instruction has no state, skipping" , Inst));
15231555 assert (UnreachableBBReported && " Should be reported at least once" );
15241556 (void )UnreachableBBReported;
15251557 return ;
@@ -1546,8 +1578,7 @@ void FunctionAnalysisContext::augmentUnsafeUseReports(
15461578 SmallVector<MCPhysReg> RegsToTrack = collectRegsToTrack (Reports);
15471579 // Re-compute the analysis with register tracking.
15481580 auto Analysis = SrcSafetyAnalysis::create (BF, AllocatorId, RegsToTrack);
1549- LLVM_DEBUG (
1550- { dbgs () << " \n Running detailed src register safety analysis...\n " ; });
1581+ LLVM_DEBUG (dbgs () << " \n Running detailed src register safety analysis...\n " );
15511582 Analysis->run ();
15521583 LLVM_DEBUG ({
15531584 dbgs () << " After detailed src register safety analysis:\n " ;
@@ -1557,7 +1588,7 @@ void FunctionAnalysisContext::augmentUnsafeUseReports(
15571588 // Augment gadget reports.
15581589 for (auto &Report : Reports) {
15591590 MCInstReference Location = Report.Issue ->Location ;
1560- LLVM_DEBUG ({ traceInst (BC, " Attaching clobbering info to" , Location); } );
1591+ LLVM_DEBUG (traceInst (BC, " Attaching clobbering info to" , Location));
15611592 assert (Report.RequestedDetails &&
15621593 " Should be removed by handleSimpleReports" );
15631594 auto DetailedInfo =
@@ -1571,9 +1602,11 @@ void FunctionAnalysisContext::findUnsafeDefs(
15711602 SmallVector<PartialReport<MCPhysReg>> &Reports) {
15721603 if (PacRetGadgetsOnly)
15731604 return ;
1605+ if (AuthTrapsOnFailure)
1606+ return ;
15741607
15751608 auto Analysis = DstSafetyAnalysis::create (BF, AllocatorId, {});
1576- LLVM_DEBUG ({ dbgs () << " Running dst register safety analysis...\n " ; } );
1609+ LLVM_DEBUG (dbgs () << " Running dst register safety analysis...\n " );
15771610 Analysis->run ();
15781611 LLVM_DEBUG ({
15791612 dbgs () << " After dst register safety analysis:\n " ;
@@ -1596,8 +1629,7 @@ void FunctionAnalysisContext::augmentUnsafeDefReports(
15961629 SmallVector<MCPhysReg> RegsToTrack = collectRegsToTrack (Reports);
15971630 // Re-compute the analysis with register tracking.
15981631 auto Analysis = DstSafetyAnalysis::create (BF, AllocatorId, RegsToTrack);
1599- LLVM_DEBUG (
1600- { dbgs () << " \n Running detailed dst register safety analysis...\n " ; });
1632+ LLVM_DEBUG (dbgs () << " \n Running detailed dst register safety analysis...\n " );
16011633 Analysis->run ();
16021634 LLVM_DEBUG ({
16031635 dbgs () << " After detailed dst register safety analysis:\n " ;
@@ -1607,7 +1639,7 @@ void FunctionAnalysisContext::augmentUnsafeDefReports(
16071639 // Augment gadget reports.
16081640 for (auto &Report : Reports) {
16091641 MCInstReference Location = Report.Issue ->Location ;
1610- LLVM_DEBUG ({ traceInst (BC, " Attaching leakage info to" , Location); } );
1642+ LLVM_DEBUG (traceInst (BC, " Attaching leakage info to" , Location));
16111643 assert (Report.RequestedDetails &&
16121644 " Should be removed by handleSimpleReports" );
16131645 auto DetailedInfo = std::make_shared<LeakageInfo>(
0 commit comments