Skip to content

Commit b1e246b

Browse files
committed
Address the review comments
1 parent 8b1df05 commit b1e246b

File tree

3 files changed

+174
-124
lines changed

3 files changed

+174
-124
lines changed

bolt/include/bolt/Passes/PAuthGadgetScanner.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -199,7 +199,7 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &);
199199

200200
namespace PAuthGadgetScanner {
201201

202-
class PacRetAnalysis;
202+
class RegisterSafetyAnalysis;
203203
struct State;
204204

205205
/// Description of a gadget kind that can be detected. Intended to be

bolt/lib/Passes/PAuthGadgetScanner.cpp

Lines changed: 109 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -124,27 +124,6 @@ class TrackedRegisters {
124124
}
125125
};
126126

127-
// Without CFG, we reset gadget scanning state when encountering an
128-
// unconditional branch. Note that BC.MIB->isUnconditionalBranch neither
129-
// considers indirect branches nor annotated tail calls as unconditional.
130-
static bool isStateTrackingBoundary(const BinaryContext &BC,
131-
const MCInst &Inst) {
132-
const MCInstrDesc &Desc = BC.MII->get(Inst.getOpcode());
133-
// Adapted from llvm::MCInstrDesc::isUnconditionalBranch().
134-
return Desc.isBranch() && Desc.isBarrier();
135-
}
136-
137-
template <typename T> static void iterateOverInstrs(BinaryFunction &BF, T Fn) {
138-
if (BF.hasCFG()) {
139-
for (BinaryBasicBlock &BB : BF)
140-
for (int64_t I = 0, E = BB.size(); I < E; ++I)
141-
Fn(MCInstInBBReference(&BB, I));
142-
} else {
143-
for (auto I : BF.instrs())
144-
Fn(MCInstInBFReference(&BF, I.first));
145-
}
146-
}
147-
148127
// The security property that is checked is:
149128
// When a register is used as the address to jump to in a return instruction,
150129
// that register must be safe-to-dereference. It must either
@@ -286,16 +265,21 @@ void PacStatePrinter::print(raw_ostream &OS, const State &S) const {
286265
OS << ">";
287266
}
288267

289-
class PacRetAnalysis {
268+
/// Computes which registers are safe to be used by control flow instructions.
269+
///
270+
/// This is the base class for two implementations: a dataflow-based analysis
271+
/// which is intended to be used for most functions and a simplified CFG-unaware
272+
/// version for functions without reconstructed CFG.
273+
class RegisterSafetyAnalysis {
290274
public:
291-
PacRetAnalysis(BinaryFunction &BF,
292-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
275+
RegisterSafetyAnalysis(BinaryFunction &BF,
276+
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
293277
: BC(BF.getBinaryContext()), NumRegs(BC.MRI->getNumRegs()),
294278
RegsToTrackInstsFor(RegsToTrackInstsFor) {}
295279

296-
virtual ~PacRetAnalysis() {}
280+
virtual ~RegisterSafetyAnalysis() {}
297281

298-
static std::shared_ptr<PacRetAnalysis>
282+
static std::shared_ptr<RegisterSafetyAnalysis>
299283
create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
300284
const std::vector<MCPhysReg> &RegsToTrackInstsFor);
301285

@@ -373,7 +357,7 @@ class PacRetAnalysis {
373357
State computeNext(const MCInst &Point, const State &Cur) {
374358
PacStatePrinter P(BC);
375359
LLVM_DEBUG({
376-
dbgs() << " PacRetAnalysis::ComputeNext(";
360+
dbgs() << " RegisterSafetyAnalysis::ComputeNext(";
377361
BC.InstPrinter->printInst(&const_cast<MCInst &>(Point), 0, "", *BC.STI,
378362
dbgs());
379363
dbgs() << ", ";
@@ -422,7 +406,7 @@ class PacRetAnalysis {
422406
}
423407

424408
LLVM_DEBUG({
425-
dbgs() << " .. result: (";
409+
dbgs() << " .. result: (";
426410
P.print(dbgs(), Next);
427411
dbgs() << ")\n";
428412
});
@@ -456,21 +440,23 @@ class PacRetAnalysis {
456440
}
457441
};
458442

459-
class PacRetDFAnalysis
460-
: public PacRetAnalysis,
461-
public DataflowAnalysis<PacRetDFAnalysis, State, /*Backward=*/false,
462-
PacStatePrinter> {
463-
using DFParent =
464-
DataflowAnalysis<PacRetDFAnalysis, State, false, PacStatePrinter>;
443+
class DataflowRegisterSafetyAnalysis
444+
: public RegisterSafetyAnalysis,
445+
public DataflowAnalysis<DataflowRegisterSafetyAnalysis, State,
446+
/*Backward=*/false, PacStatePrinter> {
447+
using DFParent = DataflowAnalysis<DataflowRegisterSafetyAnalysis, State,
448+
false, PacStatePrinter>;
465449
friend DFParent;
466450

467-
using PacRetAnalysis::BC;
468-
using PacRetAnalysis::computeNext;
451+
using RegisterSafetyAnalysis::BC;
452+
using RegisterSafetyAnalysis::computeNext;
469453

470454
public:
471-
PacRetDFAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
472-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
473-
: PacRetAnalysis(BF, RegsToTrackInstsFor), DFParent(BF, AllocId) {}
455+
DataflowRegisterSafetyAnalysis(
456+
BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
457+
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
458+
: RegisterSafetyAnalysis(BF, RegsToTrackInstsFor), DFParent(BF, AllocId) {
459+
}
474460

475461
ErrorOr<const State &> getStateBefore(const MCInst &Inst) const override {
476462
return DFParent::getStateBefore(Inst);
@@ -493,28 +479,64 @@ class PacRetDFAnalysis
493479
void doConfluence(State &StateOut, const State &StateIn) {
494480
PacStatePrinter P(BC);
495481
LLVM_DEBUG({
496-
dbgs() << " PacRetAnalysis::Confluence(\n";
497-
dbgs() << " State 1: ";
482+
dbgs() << " DataflowRegisterSafetyAnalysis::Confluence(\n";
483+
dbgs() << " State 1: ";
498484
P.print(dbgs(), StateOut);
499485
dbgs() << "\n";
500-
dbgs() << " State 2: ";
486+
dbgs() << " State 2: ";
501487
P.print(dbgs(), StateIn);
502488
dbgs() << ")\n";
503489
});
504490

505491
StateOut.merge(StateIn);
506492

507493
LLVM_DEBUG({
508-
dbgs() << " merged state: ";
494+
dbgs() << " merged state: ";
509495
P.print(dbgs(), StateOut);
510496
dbgs() << "\n";
511497
});
512498
}
513499

514-
StringRef getAnnotationName() const { return "PacRetAnalysis"; }
500+
StringRef getAnnotationName() const {
501+
return "DataflowRegisterSafetyAnalysis";
502+
}
515503
};
516504

517-
class NoCFGPacRetAnalysis : public PacRetAnalysis {
505+
// A simplified implementation of DataflowRegisterSafetyAnalysis for functions
506+
// lacking CFG information.
507+
//
508+
// Let's assume the instructions can only be executed linearly unless there is
509+
// a label to jump to - this should handle both directly jumping to a location
510+
// encoded as an immediate operand of a branch instruction, as well as saving a
511+
// branch destination somewhere and passing it to an indirect branch instruction
512+
// later, provided no arithmetic is performed on the destination address:
513+
//
514+
// ; good: the destination is directly encoded into the branch instruction
515+
// cbz x0, some_label
516+
//
517+
// ; good: the branch destination is first stored and then used as-is
518+
// adr x1, some_label
519+
// br x1
520+
//
521+
// ; bad: some clever arithmetic is performed manually
522+
// adr x1, some_label
523+
// add x1, x1, #4
524+
// br x1
525+
// ...
526+
// some_label:
527+
// ; pessimistically reset the state as we are unsure where we came from
528+
// ...
529+
// ret
530+
// JTI0:
531+
// .byte some_label - Ltmp0 ; computing offsets using labels will probably
532+
// work too, provided enough information is
533+
// retained by the assembler and linker
534+
//
535+
// Then, a function can be split into a number of disjoint contiguous sequences
536+
// of instructions without labels in between. These sequences can be processed
537+
// the same way basic blocks are processed by data-flow analysis, assuming
538+
// pessimistically that all registers are unsafe at the start of each sequence.
539+
class CFGUnawareRegisterSafetyAnalysis : public RegisterSafetyAnalysis {
518540
BinaryFunction &BF;
519541
MCPlusBuilder::AllocatorIdTy AllocId;
520542
unsigned StateAnnotationIndex;
@@ -531,11 +553,13 @@ class NoCFGPacRetAnalysis : public PacRetAnalysis {
531553
}
532554

533555
public:
534-
NoCFGPacRetAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
535-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
536-
: PacRetAnalysis(BF, RegsToTrackInstsFor), BF(BF), AllocId(AllocId) {
556+
CFGUnawareRegisterSafetyAnalysis(
557+
BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
558+
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
559+
: RegisterSafetyAnalysis(BF, RegsToTrackInstsFor), BF(BF),
560+
AllocId(AllocId) {
537561
StateAnnotationIndex =
538-
BC.MIB->getOrCreateAnnotationIndex("NoCFGPacRetAnalysis");
562+
BC.MIB->getOrCreateAnnotationIndex("CFGUnawareRegisterSafetyAnalysis");
539563
}
540564

541565
void run() override {
@@ -547,40 +571,40 @@ class NoCFGPacRetAnalysis : public PacRetAnalysis {
547571
// can be jumped-to, thus conservatively resetting S. As an exception,
548572
// let's ignore any labels at the beginning of the function, as at least
549573
// one label is expected there.
550-
if (BF.hasLabelAt(I.first) && &Inst != &BF.instrs().begin()->second)
574+
if (BF.hasLabelAt(I.first) && &Inst != &BF.instrs().begin()->second) {
575+
LLVM_DEBUG({
576+
traceInst(BC, "Due to label, resetting the state before", Inst);
577+
});
551578
S = createUnsafeState();
579+
}
552580

553581
// Check if we need to remove an old annotation (this is the case if
554582
// this is the second, detailed, run of the analysis).
555583
if (BC.MIB->hasAnnotation(Inst, StateAnnotationIndex))
556584
BC.MIB->removeAnnotation(Inst, StateAnnotationIndex);
557-
// Attach the state *before* this instruction.
585+
// Attach the state *before* this instruction executes.
558586
BC.MIB->addAnnotation(Inst, StateAnnotationIndex, S, AllocId);
559587

560588
// Compute the state after this instruction.
561-
// If this instruction is an unconditional branch (incl. indirect ones),
562-
// reset the state.
563-
if (isStateTrackingBoundary(BC, Inst))
564-
S = createUnsafeState();
565-
else
566-
S = computeNext(Inst, S);
589+
S = computeNext(Inst, S);
567590
}
568591
}
569592

570593
ErrorOr<const State &> getStateBefore(const MCInst &Inst) const override {
571594
return BC.MIB->getAnnotationAs<State>(Inst, StateAnnotationIndex);
572595
}
573596

574-
~NoCFGPacRetAnalysis() { cleanStateAnnotations(); }
597+
~CFGUnawareRegisterSafetyAnalysis() { cleanStateAnnotations(); }
575598
};
576599

577-
std::shared_ptr<PacRetAnalysis>
578-
PacRetAnalysis::create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
579-
const std::vector<MCPhysReg> &RegsToTrackInstsFor) {
600+
std::shared_ptr<RegisterSafetyAnalysis> RegisterSafetyAnalysis::create(
601+
BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
602+
const std::vector<MCPhysReg> &RegsToTrackInstsFor) {
580603
if (BF.hasCFG())
581-
return std::make_shared<PacRetDFAnalysis>(BF, AllocId, RegsToTrackInstsFor);
582-
return std::make_shared<NoCFGPacRetAnalysis>(BF, AllocId,
583-
RegsToTrackInstsFor);
604+
return std::make_shared<DataflowRegisterSafetyAnalysis>(
605+
BF, AllocId, RegsToTrackInstsFor);
606+
return std::make_shared<CFGUnawareRegisterSafetyAnalysis>(
607+
BF, AllocId, RegsToTrackInstsFor);
584608
}
585609

586610
static std::shared_ptr<Report>
@@ -634,21 +658,33 @@ shouldReportCallGadget(const BinaryContext &BC, const MCInstReference &Inst,
634658
return std::make_shared<GadgetReport>(CallKind, Inst, DestReg);
635659
}
636660

661+
template <typename T> static void iterateOverInstrs(BinaryFunction &BF, T Fn) {
662+
if (BF.hasCFG()) {
663+
for (BinaryBasicBlock &BB : BF)
664+
for (int64_t I = 0, E = BB.size(); I < E; ++I)
665+
Fn(MCInstInBBReference(&BB, I));
666+
} else {
667+
for (auto I : BF.instrs())
668+
Fn(MCInstInBFReference(&BF, I.first));
669+
}
670+
}
671+
637672
FunctionAnalysisResult
638673
Analysis::findGadgets(BinaryFunction &BF,
639674
MCPlusBuilder::AllocatorIdTy AllocatorId) {
640675
FunctionAnalysisResult Result;
641676

642-
auto PRA = PacRetAnalysis::create(BF, AllocatorId, {});
643-
PRA->run();
677+
auto RSA = RegisterSafetyAnalysis::create(BF, AllocatorId, {});
678+
LLVM_DEBUG({ dbgs() << "Running register safety analysis...\n"; });
679+
RSA->run();
644680
LLVM_DEBUG({
645-
dbgs() << " After PacRetAnalysis:\n";
681+
dbgs() << "After register safety analysis:\n";
646682
BF.dump();
647683
});
648684

649685
BinaryContext &BC = BF.getBinaryContext();
650686
iterateOverInstrs(BF, [&](MCInstReference Inst) {
651-
const State &S = *PRA->getStateBefore(Inst);
687+
const State &S = *RSA->getStateBefore(Inst);
652688

653689
// If non-empty state was never propagated from the entry basic block
654690
// to Inst, assume it to be unreachable and report a warning.
@@ -682,10 +718,11 @@ void Analysis::computeDetailedInfo(BinaryFunction &BF,
682718
std::vector<MCPhysReg> RegsToTrackVec(RegsToTrack.begin(), RegsToTrack.end());
683719

684720
// Re-compute the analysis with register tracking.
685-
auto PRWIA = PacRetAnalysis::create(BF, AllocatorId, RegsToTrackVec);
686-
PRWIA->run();
721+
auto RSWIA = RegisterSafetyAnalysis::create(BF, AllocatorId, RegsToTrackVec);
722+
LLVM_DEBUG({ dbgs() << "\nRunning detailed register safety analysis...\n"; });
723+
RSWIA->run();
687724
LLVM_DEBUG({
688-
dbgs() << " After detailed PacRetAnalysis:\n";
725+
dbgs() << "After detailed register safety analysis:\n";
689726
BF.dump();
690727
});
691728

@@ -694,7 +731,7 @@ void Analysis::computeDetailedInfo(BinaryFunction &BF,
694731
LLVM_DEBUG(
695732
{ traceInst(BC, "Attaching clobbering info to", Report->Location); });
696733
(void)BC;
697-
Report->setOverwritingInstrs(PRWIA->getLastClobberingInsts(
734+
Report->setOverwritingInstrs(RSWIA->getLastClobberingInsts(
698735
Report->Location, BF, Report->getAffectedRegisters()));
699736
}
700737
}

0 commit comments

Comments
 (0)