1414#include " bolt/Passes/NonPacProtectedRetAnalysis.h"
1515#include " bolt/Core/ParallelUtilities.h"
1616#include " bolt/Passes/DataflowAnalysis.h"
17+ #include " llvm/ADT/STLExtras.h"
1718#include " llvm/ADT/SmallSet.h"
1819#include " llvm/MC/MCInst.h"
1920#include " llvm/Support/Format.h"
@@ -58,6 +59,71 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &Ref) {
5859
5960namespace NonPacProtectedRetAnalysis {
6061
62+ static void traceInst (const BinaryContext &BC, StringRef Label,
63+ const MCInst &MI) {
64+ dbgs () << " " << Label << " : " ;
65+ BC.printInstruction (dbgs (), MI);
66+ }
67+
68+ static void traceReg (const BinaryContext &BC, StringRef Label,
69+ ErrorOr<MCPhysReg> Reg) {
70+ dbgs () << " " << Label << " : " ;
71+ if (Reg.getError ())
72+ dbgs () << " (error)" ;
73+ else if (*Reg == BC.MIB ->getNoRegister ())
74+ dbgs () << " (none)" ;
75+ else
76+ dbgs () << BC.MRI ->getName (*Reg);
77+ dbgs () << " \n " ;
78+ }
79+
80+ static void traceRegMask (const BinaryContext &BC, StringRef Label,
81+ BitVector Mask) {
82+ dbgs () << " " << Label << " : " ;
83+ RegStatePrinter (BC).print (dbgs (), Mask);
84+ dbgs () << " \n " ;
85+ }
86+
87+ // This class represents mapping from a set of arbitrary physical registers to
88+ // consecutive array indexes.
89+ class TrackedRegisters {
90+ static constexpr uint16_t NoIndex = -1 ;
91+ const std::vector<MCPhysReg> Registers;
92+ std::vector<uint16_t > RegToIndexMapping;
93+
94+ static size_t getMappingSize (const std::vector<MCPhysReg> &RegsToTrack) {
95+ if (RegsToTrack.empty ())
96+ return 0 ;
97+ return 1 + *llvm::max_element (RegsToTrack);
98+ }
99+
100+ public:
101+ TrackedRegisters (const std::vector<MCPhysReg> &RegsToTrack)
102+ : Registers(RegsToTrack),
103+ RegToIndexMapping (getMappingSize(RegsToTrack), NoIndex) {
104+ for (unsigned I = 0 ; I < RegsToTrack.size (); ++I)
105+ RegToIndexMapping[RegsToTrack[I]] = I;
106+ }
107+
108+ const ArrayRef<MCPhysReg> getRegisters () const { return Registers; }
109+
110+ size_t getNumTrackedRegisters () const { return Registers.size (); }
111+
112+ bool empty () const { return Registers.empty (); }
113+
114+ bool isTracked (MCPhysReg Reg) const {
115+ bool IsTracked = (unsigned )Reg < RegToIndexMapping.size () &&
116+ RegToIndexMapping[Reg] != NoIndex;
117+ assert (IsTracked == llvm::is_contained (Registers, Reg));
118+ return IsTracked;
119+ }
120+
121+ unsigned getIndex (MCPhysReg Reg) const {
122+ assert (isTracked (Reg) && " Register is not tracked" );
123+ return RegToIndexMapping[Reg];
124+ }
125+ };
126+
61127// The security property that is checked is:
62128// When a register is used as the address to jump to in a return instruction,
63129// that register must either:
@@ -169,52 +235,34 @@ class PacRetAnalysis
169235 PacRetAnalysis (BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
170236 const std::vector<MCPhysReg> &RegsToTrackInstsFor)
171237 : Parent(BF, AllocId), NumRegs(BF.getBinaryContext().MRI->getNumRegs ()),
172- RegsToTrackInstsFor(RegsToTrackInstsFor),
173- TrackingLastInsts(!RegsToTrackInstsFor.empty()),
174- Reg2StateIdx(RegsToTrackInstsFor.empty()
175- ? 0
176- : *llvm::max_element(RegsToTrackInstsFor) + 1,
177- -1) {
178- for (unsigned I = 0 ; I < RegsToTrackInstsFor.size (); ++I)
179- Reg2StateIdx[RegsToTrackInstsFor[I]] = I;
180- }
238+ RegsToTrackInstsFor(RegsToTrackInstsFor) {}
181239 virtual ~PacRetAnalysis () {}
182240
183241protected:
184242 const unsigned NumRegs;
185243 // / RegToTrackInstsFor is the set of registers for which the dataflow analysis
186244 // / must compute which the last set of instructions writing to it are.
187- const std::vector<MCPhysReg> RegsToTrackInstsFor;
188- const bool TrackingLastInsts;
189- // / Reg2StateIdx maps Register to the index in the vector used in State to
190- // / track which instructions last wrote to this register.
191- std::vector<uint16_t > Reg2StateIdx;
245+ const TrackedRegisters RegsToTrackInstsFor;
192246
193247 SmallPtrSet<const MCInst *, 4 > &lastWritingInsts (State &S,
194248 MCPhysReg Reg) const {
195- assert (Reg < Reg2StateIdx.size ());
196- assert (isTrackingReg (Reg));
197- return S.LastInstWritingReg [Reg2StateIdx[Reg]];
249+ unsigned Index = RegsToTrackInstsFor.getIndex (Reg);
250+ return S.LastInstWritingReg [Index];
198251 }
199252 const SmallPtrSet<const MCInst *, 4 > &lastWritingInsts (const State &S,
200253 MCPhysReg Reg) const {
201- assert (Reg < Reg2StateIdx.size ());
202- assert (isTrackingReg (Reg));
203- return S.LastInstWritingReg [Reg2StateIdx[Reg]];
204- }
205-
206- bool isTrackingReg (MCPhysReg Reg) const {
207- return llvm::is_contained (RegsToTrackInstsFor, Reg);
254+ unsigned Index = RegsToTrackInstsFor.getIndex (Reg);
255+ return S.LastInstWritingReg [Index];
208256 }
209257
210258 void preflight () {}
211259
212260 State getStartingStateAtBB (const BinaryBasicBlock &BB) {
213- return State (NumRegs, RegsToTrackInstsFor.size ());
261+ return State (NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
214262 }
215263
216264 State getStartingStateAtPoint (const MCInst &Point) {
217- return State (NumRegs, RegsToTrackInstsFor.size ());
265+ return State (NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
218266 }
219267
220268 void doConfluence (State &StateOut, const State &StateIn) {
@@ -275,7 +323,7 @@ class PacRetAnalysis
275323 Next.NonAutClobRegs |= Written;
276324 // Keep track of this instruction if it writes to any of the registers we
277325 // need to track that for:
278- for (MCPhysReg Reg : RegsToTrackInstsFor)
326+ for (MCPhysReg Reg : RegsToTrackInstsFor. getRegisters () )
279327 if (Written[Reg])
280328 lastWritingInsts (Next, Reg) = {&Point};
281329
@@ -287,7 +335,7 @@ class PacRetAnalysis
287335 // https://github.com/llvm/llvm-project/pull/122304#discussion_r1939515516
288336 Next.NonAutClobRegs .reset (
289337 BC.MIB ->getAliases (*AutReg, /* OnlySmaller=*/ true ));
290- if (TrackingLastInsts && isTrackingReg (*AutReg))
338+ if (RegsToTrackInstsFor. isTracked (*AutReg))
291339 lastWritingInsts (Next, *AutReg).clear ();
292340 }
293341
@@ -306,7 +354,7 @@ class PacRetAnalysis
306354 std::vector<MCInstReference>
307355 getLastClobberingInsts (const MCInst Ret, BinaryFunction &BF,
308356 const BitVector &UsedDirtyRegs) const {
309- if (!TrackingLastInsts )
357+ if (RegsToTrackInstsFor. empty () )
310358 return {};
311359 auto MaybeState = getStateAt (Ret);
312360 if (!MaybeState)
@@ -355,28 +403,18 @@ Analysis::computeDfState(PacRetAnalysis &PRA, BinaryFunction &BF,
355403 }
356404 MCPhysReg RetReg = *MaybeRetReg;
357405 LLVM_DEBUG ({
358- dbgs () << " Found RET inst: " ;
359- BC.printInstruction (dbgs (), Inst);
360- dbgs () << " RetReg: " << BC.MRI ->getName (RetReg)
361- << " ; authenticatesReg: "
362- << BC.MIB ->isAuthenticationOfReg (Inst, RetReg) << " \n " ;
406+ traceInst (BC, " Found RET inst" , Inst);
407+ traceReg (BC, " RetReg" , RetReg);
408+ traceReg (BC, " Authenticated reg" , BC.MIB ->getAuthenticatedReg (Inst));
363409 });
364410 if (BC.MIB ->isAuthenticationOfReg (Inst, RetReg))
365411 break ;
366412 BitVector UsedDirtyRegs = PRA.getStateAt (Inst)->NonAutClobRegs ;
367- LLVM_DEBUG ({
368- dbgs () << " NonAutClobRegs at Ret: " ;
369- RegStatePrinter RSP (BC);
370- RSP.print (dbgs (), UsedDirtyRegs);
371- dbgs () << " \n " ;
372- });
413+ LLVM_DEBUG (
414+ { traceRegMask (BC, " NonAutClobRegs at Ret" , UsedDirtyRegs); });
373415 UsedDirtyRegs &= BC.MIB ->getAliases (RetReg, /* OnlySmaller=*/ true );
374- LLVM_DEBUG ({
375- dbgs () << " Intersection with RetReg: " ;
376- RegStatePrinter RSP (BC);
377- RSP.print (dbgs (), UsedDirtyRegs);
378- dbgs () << " \n " ;
379- });
416+ LLVM_DEBUG (
417+ { traceRegMask (BC, " Intersection with RetReg" , UsedDirtyRegs); });
380418 if (UsedDirtyRegs.any ()) {
381419 // This return instruction needs to be reported
382420 Result.Diagnostics .push_back (std::make_shared<Gadget>(
@@ -472,12 +510,6 @@ void Gadget::generateReport(raw_ostream &OS, const BinaryContext &BC) const {
472510 OS << " " << (I + 1 ) << " . " ;
473511 BC.printInstruction (OS, InstRef, InstRef.getAddress (), BF);
474512 };
475- LLVM_DEBUG ({
476- dbgs () << " .. OverWritingRetRegInst:\n " ;
477- for (MCInstReference Ref : OverwritingRetRegInst) {
478- dbgs () << " " << Ref << " \n " ;
479- }
480- });
481513 if (OverwritingRetRegInst.size () == 1 ) {
482514 const MCInstReference OverwInst = OverwritingRetRegInst[0 ];
483515 assert (OverwInst.ParentKind == MCInstReference::BasicBlockParent);
0 commit comments