1414#include  " bolt/Passes/PAuthGadgetScanner.h" 
1515#include  " bolt/Core/ParallelUtilities.h" 
1616#include  " bolt/Passes/DataflowAnalysis.h" 
17+ #include  " bolt/Utils/CommandLineOpts.h" 
1718#include  " llvm/ADT/STLExtras.h" 
1819#include  " llvm/ADT/SmallSet.h" 
1920#include  " llvm/MC/MCInst.h" 
@@ -26,6 +27,11 @@ namespace llvm {
2627namespace  bolt  {
2728namespace  PAuthGadgetScanner  {
2829
30+ static  cl::opt<bool > AuthTrapsOnFailure (
31+     " auth-traps-on-failure" 
32+     cl::desc (" Assume authentication instructions always trap on failure" 
33+     cl::cat(opts::BinaryAnalysisCategory));
34+ 
2935[[maybe_unused]] static  void  traceInst (const  BinaryContext &BC, StringRef Label,
3036                                       const  MCInst &MI) {
3137  dbgs () << "   " " : " 
@@ -82,8 +88,8 @@ class TrackedRegisters {
8288  TrackedRegisters (ArrayRef<MCPhysReg> RegsToTrack)
8389      : Registers(RegsToTrack),
8490        RegToIndexMapping (getMappingSize(RegsToTrack), NoIndex) {
85-     for  (unsigned  I =  0 ; I < RegsToTrack. size (); ++I )
86-       RegToIndexMapping[RegsToTrack[I]]  = I ;
91+     for  (auto  [MappedIndex, Reg] :  llvm::enumerate (RegsToTrack) )
92+       RegToIndexMapping[Reg]  = MappedIndex ;
8793  }
8894
8995  ArrayRef<MCPhysReg> getRegisters () const  { return  Registers; }
@@ -197,9 +203,9 @@ struct SrcState {
197203
198204    SafeToDerefRegs &= StateIn.SafeToDerefRegs ;
199205    TrustedRegs &= StateIn.TrustedRegs ;
200-     for  (unsigned  I =  0 ; I < LastInstWritingReg. size (); ++I) 
201-       for  ( const  MCInst *J :  StateIn.LastInstWritingReg [I] )
202-         LastInstWritingReg[I]. insert (J );
206+     for  (auto  [ThisSet, OtherSet] : 
207+           llvm::zip_equal (LastInstWritingReg,  StateIn.LastInstWritingReg ) )
208+       ThisSet. insert_range (OtherSet );
203209    return  *this ;
204210  }
205211
@@ -218,11 +224,9 @@ struct SrcState {
218224static  void  printInstsShort (raw_ostream &OS,
219225                            ArrayRef<SetOfRelatedInsts> Insts) {
220226  OS << " Insts: " 
221-   for  (unsigned  I = 0 ; I < Insts.size (); ++I) {
222-     auto  &Set = Insts[I];
227+   for  (auto  [I, PtrSet] : llvm::enumerate (Insts)) {
223228    OS << " [" " ](" 
224-     for  (const  MCInst *MCInstP : Set)
225-       OS << MCInstP << "  " 
229+     interleave (PtrSet, OS, "  " 
226230    OS << " )" 
227231  }
228232}
@@ -364,6 +368,34 @@ class SrcSafetyAnalysis {
364368    return  Clobbered;
365369  }
366370
371+   std::optional<MCPhysReg> getRegMadeTrustedByChecking (const  MCInst &Inst,
372+                                                        SrcState Cur) const  {
373+     //  This function cannot return multiple registers. This is never the case
374+     //  on AArch64.
375+     std::optional<MCPhysReg> RegCheckedByInst =
376+         BC.MIB ->getAuthCheckedReg (Inst, /* MayOverwrite=*/ false );
377+     if  (RegCheckedByInst && Cur.SafeToDerefRegs [*RegCheckedByInst])
378+       return  *RegCheckedByInst;
379+ 
380+     auto  It = CheckerSequenceInfo.find (&Inst);
381+     if  (It == CheckerSequenceInfo.end ())
382+       return  std::nullopt ;
383+ 
384+     MCPhysReg RegCheckedBySequence = It->second .first ;
385+     const  MCInst *FirstCheckerInst = It->second .second ;
386+ 
387+     //  FirstCheckerInst should belong to the same basic block (see the
388+     //  assertion in DataflowSrcSafetyAnalysis::run()), meaning it was
389+     //  deterministically processed a few steps before this instruction.
390+     const  SrcState &StateBeforeChecker = getStateBefore (*FirstCheckerInst);
391+ 
392+     //  The sequence checks the register, but it should be authenticated before.
393+     if  (!StateBeforeChecker.SafeToDerefRegs [RegCheckedBySequence])
394+       return  std::nullopt ;
395+ 
396+     return  RegCheckedBySequence;
397+   }
398+ 
367399  //  Returns all registers that can be treated as if they are written by an
368400  //  authentication instruction.
369401  SmallVector<MCPhysReg> getRegsMadeSafeToDeref (const  MCInst &Point,
@@ -382,22 +414,43 @@ class SrcSafetyAnalysis {
382414    //  ... an address can be updated in a safe manner, producing the result
383415    //  which is as trusted as the input address.
384416    if  (auto  DstAndSrc = BC.MIB ->analyzeAddressArithmeticsForPtrAuth (Point)) {
385-       if  (Cur.SafeToDerefRegs [DstAndSrc->second ])
386-         Regs.push_back (DstAndSrc->first );
417+       auto  [DstReg, SrcReg] = *DstAndSrc;
418+       if  (Cur.SafeToDerefRegs [SrcReg])
419+         Regs.push_back (DstReg);
387420    }
388421
422+     //  Make sure an explicit checker sequence keeps the register
423+     //  safe-to-dereference when it would be clobbered by the regular rules:
424+     // 
425+     //     ; LR is safe to dereference here
426+     //     mov   x16, x30  ; start of the sequence, LR is s-t-d right before
427+     //     xpaclri         ; clobbers LR, LR is not safe anymore
428+     //     cmp   x30, x16
429+     //     b.eq  1f        ; end of the sequence: LR is marked as trusted
430+     //     brk   0x1234
431+     //   1:
432+     //     ; at this point LR would be marked as trusted,
433+     //     ; but not safe-to-dereference
434+     // 
435+     //  or even just
436+     // 
437+     //     ; X1 is safe to dereference here
438+     //     ldr x0, [x1, #8]!
439+     //     ; X1 is trusted here, but it was clobbered due to address write-back
440+     if  (auto  CheckedReg = getRegMadeTrustedByChecking (Point, Cur))
441+       Regs.push_back (*CheckedReg);
442+ 
389443    return  Regs;
390444  }
391445
392446  //  Returns all registers made trusted by this instruction.
393447  SmallVector<MCPhysReg> getRegsMadeTrusted (const  MCInst &Point,
394448                                            const  SrcState &Cur) const  {
449+     assert (!AuthTrapsOnFailure && " Use getRegsMadeSafeToDeref instead" 
395450    SmallVector<MCPhysReg> Regs;
396451
397452    //  An authenticated pointer can be checked, or
398-     std::optional<MCPhysReg> CheckedReg =
399-         BC.MIB ->getAuthCheckedReg (Point, /* MayOverwrite=*/ false );
400-     if  (CheckedReg && Cur.SafeToDerefRegs [*CheckedReg])
453+     if  (auto  CheckedReg = getRegMadeTrustedByChecking (Point, Cur))
401454      Regs.push_back (*CheckedReg);
402455
403456    //  ... a pointer can be authenticated by an instruction that always checks
@@ -408,28 +461,16 @@ class SrcSafetyAnalysis {
408461    if  (AutReg && IsChecked)
409462      Regs.push_back (*AutReg);
410463
411-     if  (CheckerSequenceInfo.contains (&Point)) {
412-       MCPhysReg CheckedReg;
413-       const  MCInst *FirstCheckerInst;
414-       std::tie (CheckedReg, FirstCheckerInst) = CheckerSequenceInfo.at (&Point);
415- 
416-       //  FirstCheckerInst should belong to the same basic block (see the
417-       //  assertion in DataflowSrcSafetyAnalysis::run()), meaning it was
418-       //  deterministically processed a few steps before this instruction.
419-       const  SrcState &StateBeforeChecker = getStateBefore (*FirstCheckerInst);
420-       if  (StateBeforeChecker.SafeToDerefRegs [CheckedReg])
421-         Regs.push_back (CheckedReg);
422-     }
423- 
424464    //  ... a safe address can be materialized, or
425465    if  (auto  NewAddrReg = BC.MIB ->getMaterializedAddressRegForPtrAuth (Point))
426466      Regs.push_back (*NewAddrReg);
427467
428468    //  ... an address can be updated in a safe manner, producing the result
429469    //  which is as trusted as the input address.
430470    if  (auto  DstAndSrc = BC.MIB ->analyzeAddressArithmeticsForPtrAuth (Point)) {
431-       if  (Cur.TrustedRegs [DstAndSrc->second ])
432-         Regs.push_back (DstAndSrc->first );
471+       auto  [DstReg, SrcReg] = *DstAndSrc;
472+       if  (Cur.TrustedRegs [SrcReg])
473+         Regs.push_back (DstReg);
433474    }
434475
435476    return  Regs;
@@ -463,28 +504,11 @@ class SrcSafetyAnalysis {
463504    BitVector Clobbered = getClobberedRegs (Point);
464505    SmallVector<MCPhysReg> NewSafeToDerefRegs =
465506        getRegsMadeSafeToDeref (Point, Cur);
466-     SmallVector<MCPhysReg> NewTrustedRegs = getRegsMadeTrusted (Point, Cur);
467- 
468-     //  Ideally, being trusted is a strictly stronger property than being
469-     //  safe-to-dereference. To simplify the computation of Next state, enforce
470-     //  this for NewSafeToDerefRegs and NewTrustedRegs. Additionally, this
471-     //  fixes the properly for "cumulative" register states in tricky cases
472-     //  like the following:
473-     // 
474-     //     ; LR is safe to dereference here
475-     //     mov   x16, x30  ; start of the sequence, LR is s-t-d right before
476-     //     xpaclri         ; clobbers LR, LR is not safe anymore
477-     //     cmp   x30, x16
478-     //     b.eq  1f        ; end of the sequence: LR is marked as trusted
479-     //     brk   0x1234
480-     //   1:
481-     //     ; at this point LR would be marked as trusted,
482-     //     ; but not safe-to-dereference
483-     // 
484-     for  (auto  TrustedReg : NewTrustedRegs) {
485-       if  (!is_contained (NewSafeToDerefRegs, TrustedReg))
486-         NewSafeToDerefRegs.push_back (TrustedReg);
487-     }
507+     //  If authentication instructions trap on failure, safe-to-dereference
508+     //  registers are always trusted.
509+     SmallVector<MCPhysReg> NewTrustedRegs =
510+         AuthTrapsOnFailure ? NewSafeToDerefRegs
511+                            : getRegsMadeTrusted (Point, Cur);
488512
489513    //  Then, compute the state after this instruction is executed.
490514    SrcState Next = Cur;
@@ -521,6 +545,11 @@ class SrcSafetyAnalysis {
521545      dbgs () << " )\n " 
522546    });
523547
548+     //  Being trusted is a strictly stronger property than being
549+     //  safe-to-dereference.
550+     assert (!Next.TrustedRegs .test (Next.SafeToDerefRegs ) &&
551+            " SafeToDerefRegs should contain all TrustedRegs" 
552+ 
524553    return  Next;
525554  }
526555
@@ -836,9 +865,9 @@ struct DstState {
836865      return  (*this  = StateIn);
837866
838867    CannotEscapeUnchecked &= StateIn.CannotEscapeUnchecked ;
839-     for  (unsigned  I =  0 ; I < FirstInstLeakingReg. size (); ++I) 
840-       for  ( const  MCInst *J :  StateIn.FirstInstLeakingReg [I] )
841-         FirstInstLeakingReg[I]. insert (J );
868+     for  (auto  [ThisSet, OtherSet] : 
869+           llvm::zip_equal (FirstInstLeakingReg,  StateIn.FirstInstLeakingReg ) )
870+       ThisSet. insert_range (OtherSet );
842871    return  *this ;
843872  }
844873
@@ -1004,8 +1033,7 @@ class DstSafetyAnalysis {
10041033
10051034    //  ... an address can be updated in a safe manner, or
10061035    if  (auto  DstAndSrc = BC.MIB ->analyzeAddressArithmeticsForPtrAuth (Inst)) {
1007-       MCPhysReg DstReg, SrcReg;
1008-       std::tie (DstReg, SrcReg) = *DstAndSrc;
1036+       auto  [DstReg, SrcReg] = *DstAndSrc;
10091037      //  Note that *all* registers containing the derived values must be safe,
10101038      //  both source and destination ones. No temporaries are supported for now.
10111039      if  (Cur.CannotEscapeUnchecked [SrcReg] &&
@@ -1045,7 +1073,7 @@ class DstSafetyAnalysis {
10451073    //  If this instruction terminates the program immediately, no
10461074    //  authentication oracles are possible past this point.
10471075    if  (BC.MIB ->isTrap (Point)) {
1048-       LLVM_DEBUG ({  traceInst (BC, " Trap instruction found" ; } );
1076+       LLVM_DEBUG (traceInst (BC, " Trap instruction found" 
10491077      DstState Next (NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
10501078      Next.CannotEscapeUnchecked .set ();
10511079      return  Next;
@@ -1130,6 +1158,11 @@ class DataflowDstSafetyAnalysis
11301158  }
11311159
11321160  void  run () override  {
1161+     //  As long as DstSafetyAnalysis is only computed to detect authentication
1162+     //  oracles, it is a waste of time to compute it when authentication
1163+     //  instructions are known to always trap on failure.
1164+     assert (!AuthTrapsOnFailure &&
1165+            " DstSafetyAnalysis is useless with faulting auth" 
11331166    for  (BinaryBasicBlock &BB : Func) {
11341167      if  (auto  CheckerInfo = BC.MIB ->getAuthCheckedReg (BB)) {
11351168        LLVM_DEBUG ({
@@ -1215,7 +1248,7 @@ class CFGUnawareDstSafetyAnalysis : public DstSafetyAnalysis,
12151248      //  starting to analyze Inst.
12161249      if  (BC.MIB ->isCall (Inst) || BC.MIB ->isBranch (Inst) ||
12171250          BC.MIB ->isReturn (Inst)) {
1218-         LLVM_DEBUG ({  traceInst (BC, " Control flow instruction" ; } );
1251+         LLVM_DEBUG (traceInst (BC, " Control flow instruction" 
12191252        S = createUnsafeState ();
12201253      }
12211254
@@ -1360,7 +1393,7 @@ shouldReportUnsafeTailCall(const BinaryContext &BC, const BinaryFunction &BF,
13601393  //  such libc, ignore tail calls performed by ELF entry function.
13611394  if  (BC.StartFunctionAddress  &&
13621395      *BC.StartFunctionAddress  == Inst.getFunction ()->getAddress ()) {
1363-     LLVM_DEBUG ({  dbgs () << "   Skipping tail call in ELF entry function.\n " ; } );
1396+     LLVM_DEBUG (dbgs () << "   Skipping tail call in ELF entry function.\n " 
13641397    return  std::nullopt ;
13651398  }
13661399
@@ -1434,7 +1467,7 @@ shouldReportAuthOracle(const BinaryContext &BC, const MCInstReference &Inst,
14341467  });
14351468
14361469  if  (S.empty ()) {
1437-     LLVM_DEBUG ({  dbgs () << "     DstState is empty!\n " ; } );
1470+     LLVM_DEBUG (dbgs () << "     DstState is empty!\n " 
14381471    return  make_generic_report (
14391472        Inst, " Warning: no state computed for an authentication instruction " 
14401473              " (possibly unreachable)" 
@@ -1461,7 +1494,7 @@ collectRegsToTrack(ArrayRef<PartialReport<MCPhysReg>> Reports) {
14611494void  FunctionAnalysisContext::findUnsafeUses (
14621495    SmallVector<PartialReport<MCPhysReg>> &Reports) {
14631496  auto  Analysis = SrcSafetyAnalysis::create (BF, AllocatorId, {});
1464-   LLVM_DEBUG ({  dbgs () << " Running src register safety analysis...\n " ; } );
1497+   LLVM_DEBUG (dbgs () << " Running src register safety analysis...\n " 
14651498  Analysis->run ();
14661499  LLVM_DEBUG ({
14671500    dbgs () << " After src register safety analysis:\n " 
@@ -1518,8 +1551,7 @@ void FunctionAnalysisContext::findUnsafeUses(
15181551
15191552    const  SrcState &S = Analysis->getStateBefore (Inst);
15201553    if  (S.empty ()) {
1521-       LLVM_DEBUG (
1522-           { traceInst (BC, " Instruction has no state, skipping" 
1554+       LLVM_DEBUG (traceInst (BC, " Instruction has no state, skipping" 
15231555      assert (UnreachableBBReported && " Should be reported at least once" 
15241556      (void )UnreachableBBReported;
15251557      return ;
@@ -1546,8 +1578,7 @@ void FunctionAnalysisContext::augmentUnsafeUseReports(
15461578  SmallVector<MCPhysReg> RegsToTrack = collectRegsToTrack (Reports);
15471579  //  Re-compute the analysis with register tracking.
15481580  auto  Analysis = SrcSafetyAnalysis::create (BF, AllocatorId, RegsToTrack);
1549-   LLVM_DEBUG (
1550-       { dbgs () << " \n Running detailed src register safety analysis...\n " 
1581+   LLVM_DEBUG (dbgs () << " \n Running detailed src register safety analysis...\n " 
15511582  Analysis->run ();
15521583  LLVM_DEBUG ({
15531584    dbgs () << " After detailed src register safety analysis:\n " 
@@ -1557,7 +1588,7 @@ void FunctionAnalysisContext::augmentUnsafeUseReports(
15571588  //  Augment gadget reports.
15581589  for  (auto  &Report : Reports) {
15591590    MCInstReference Location = Report.Issue ->Location ;
1560-     LLVM_DEBUG ({  traceInst (BC, " Attaching clobbering info to" ; } );
1591+     LLVM_DEBUG (traceInst (BC, " Attaching clobbering info to" 
15611592    assert (Report.RequestedDetails  &&
15621593           " Should be removed by handleSimpleReports" 
15631594    auto  DetailedInfo =
@@ -1571,9 +1602,11 @@ void FunctionAnalysisContext::findUnsafeDefs(
15711602    SmallVector<PartialReport<MCPhysReg>> &Reports) {
15721603  if  (PacRetGadgetsOnly)
15731604    return ;
1605+   if  (AuthTrapsOnFailure)
1606+     return ;
15741607
15751608  auto  Analysis = DstSafetyAnalysis::create (BF, AllocatorId, {});
1576-   LLVM_DEBUG ({  dbgs () << " Running dst register safety analysis...\n " ; } );
1609+   LLVM_DEBUG (dbgs () << " Running dst register safety analysis...\n " 
15771610  Analysis->run ();
15781611  LLVM_DEBUG ({
15791612    dbgs () << " After dst register safety analysis:\n " 
@@ -1596,8 +1629,7 @@ void FunctionAnalysisContext::augmentUnsafeDefReports(
15961629  SmallVector<MCPhysReg> RegsToTrack = collectRegsToTrack (Reports);
15971630  //  Re-compute the analysis with register tracking.
15981631  auto  Analysis = DstSafetyAnalysis::create (BF, AllocatorId, RegsToTrack);
1599-   LLVM_DEBUG (
1600-       { dbgs () << " \n Running detailed dst register safety analysis...\n " 
1632+   LLVM_DEBUG (dbgs () << " \n Running detailed dst register safety analysis...\n " 
16011633  Analysis->run ();
16021634  LLVM_DEBUG ({
16031635    dbgs () << " After detailed dst register safety analysis:\n " 
@@ -1607,7 +1639,7 @@ void FunctionAnalysisContext::augmentUnsafeDefReports(
16071639  //  Augment gadget reports.
16081640  for  (auto  &Report : Reports) {
16091641    MCInstReference Location = Report.Issue ->Location ;
1610-     LLVM_DEBUG ({  traceInst (BC, " Attaching leakage info to" ; } );
1642+     LLVM_DEBUG (traceInst (BC, " Attaching leakage info to" 
16111643    assert (Report.RequestedDetails  &&
16121644           " Should be removed by handleSimpleReports" 
16131645    auto  DetailedInfo = std::make_shared<LeakageInfo>(
0 commit comments