Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//===- bolt/Passes/NonPacProtectedRetAnalysis.h -----------------*- C++ -*-===//
//===- bolt/Passes/PAuthGadgetScanner.h -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef BOLT_PASSES_NONPACPROTECTEDRETANALYSIS_H
#define BOLT_PASSES_NONPACPROTECTEDRETANALYSIS_H
#ifndef BOLT_PASSES_PAUTHGADGETSCANNER_H
#define BOLT_PASSES_PAUTHGADGETSCANNER_H

#include "bolt/Core/BinaryContext.h"
#include "bolt/Core/BinaryFunction.h"
Expand Down Expand Up @@ -173,63 +173,59 @@ struct MCInstReference {

raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &);

struct GeneralDiagnostic {
std::string Text;
GeneralDiagnostic(const std::string &Text) : Text(Text) {}
bool operator==(const GeneralDiagnostic &RHS) const {
return Text == RHS.Text;
}
namespace PAuthGadgetScanner {

class PacRetAnalysis;
struct State;

/// Description of a gadget kind that can be detected. Intended to be
/// statically allocated to be attached to reports by reference.
class GadgetKind {
const char *Description;

public:
GadgetKind(const char *Description) : Description(Description) {}

const StringRef getDescription() const { return Description; }
};

raw_ostream &operator<<(raw_ostream &OS, const GeneralDiagnostic &Diag);
/// Base report located at some instruction, without any additional information.
struct Report {
MCInstReference Location;

Report(MCInstReference Location) : Location(Location) {}
virtual ~Report() {}

namespace NonPacProtectedRetAnalysis {
struct Annotation {
MCInstReference RetInst;
Annotation(MCInstReference RetInst) : RetInst(RetInst) {}
virtual bool operator==(const Annotation &RHS) const {
return RetInst == RHS.RetInst;
}
Annotation &operator=(const Annotation &Other) {
if (this == &Other)
return *this;
RetInst = Other.RetInst;
return *this;
}
virtual ~Annotation() {}
virtual void generateReport(raw_ostream &OS,
const BinaryContext &BC) const = 0;

void printBasicInfo(raw_ostream &OS, const BinaryContext &BC,
StringRef IssueKind) const;
};

struct Gadget : public Annotation {
std::vector<MCInstReference> OverwritingRetRegInst;
virtual bool operator==(const Gadget &RHS) const {
return Annotation::operator==(RHS) &&
OverwritingRetRegInst == RHS.OverwritingRetRegInst;
}
Gadget(MCInstReference RetInst,
const std::vector<MCInstReference> &OverwritingRetRegInst)
: Annotation(RetInst), OverwritingRetRegInst(OverwritingRetRegInst) {}
virtual void generateReport(raw_ostream &OS,
const BinaryContext &BC) const override;
struct GadgetReport : public Report {
const GadgetKind &Kind;
std::vector<MCInstReference> OverwritingInstrs;

GadgetReport(const GadgetKind &Kind, MCInstReference Location,
std::vector<MCInstReference> OverwritingInstrs)
: Report(Location), Kind(Kind), OverwritingInstrs(OverwritingInstrs) {}

void generateReport(raw_ostream &OS, const BinaryContext &BC) const override;
};

struct GenDiag : public Annotation {
GeneralDiagnostic Diag;
virtual bool operator==(const GenDiag &RHS) const {
return Annotation::operator==(RHS) && Diag == RHS.Diag;
}
GenDiag(MCInstReference RetInst, const std::string &Text)
: Annotation(RetInst), Diag(Text) {}
/// Report with a free-form message attached.
struct GenericReport : public Report {
std::string Text;
GenericReport(MCInstReference Location, const std::string &Text)
: Report(Location), Text(Text) {}
virtual void generateReport(raw_ostream &OS,
const BinaryContext &BC) const override;
};

class PacRetAnalysis;

struct FunctionAnalysisResult {
SmallSet<MCPhysReg, 1> RegistersAffected;
std::vector<std::shared_ptr<Annotation>> Diagnostics;
std::vector<std::shared_ptr<Report>> Diagnostics;
};

class Analysis : public BinaryFunctionPass {
Expand All @@ -245,13 +241,13 @@ class Analysis : public BinaryFunctionPass {
public:
explicit Analysis() : BinaryFunctionPass(false) {}

const char *getName() const override { return "non-pac-protected-rets"; }
const char *getName() const override { return "pauth-gadget-scanner"; }

/// Pass entry point
Error runOnFunctions(BinaryContext &BC) override;
};

} // namespace NonPacProtectedRetAnalysis
} // namespace PAuthGadgetScanner
} // namespace bolt
} // namespace llvm

Expand Down
2 changes: 1 addition & 1 deletion bolt/lib/Passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ add_llvm_library(LLVMBOLTPasses
LoopInversionPass.cpp
LivenessAnalysis.cpp
MCF.cpp
NonPacProtectedRetAnalysis.cpp
PatchEntries.cpp
PAuthGadgetScanner.cpp
PettisAndHansen.cpp
PLTCall.cpp
ProfileQualityStats.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- bolt/Passes/NonPacProtectedRetAnalysis.cpp -------------------------===//
//===- bolt/Passes/PAuthGadgetScanner.cpp ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -11,7 +11,7 @@
//
//===----------------------------------------------------------------------===//

#include "bolt/Passes/NonPacProtectedRetAnalysis.h"
#include "bolt/Passes/PAuthGadgetScanner.h"
#include "bolt/Core/ParallelUtilities.h"
#include "bolt/Passes/DataflowAnalysis.h"
#include "llvm/ADT/STLExtras.h"
Expand All @@ -20,7 +20,7 @@
#include "llvm/Support/Format.h"
#include <memory>

#define DEBUG_TYPE "bolt-nonpacprotectedret"
#define DEBUG_TYPE "bolt-pauth-scanner"

namespace llvm {
namespace bolt {
Expand Down Expand Up @@ -57,7 +57,7 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &Ref) {
llvm_unreachable("");
}

namespace NonPacProtectedRetAnalysis {
namespace PAuthGadgetScanner {

[[maybe_unused]] static void traceInst(const BinaryContext &BC, StringRef Label,
const MCInst &MI) {
Expand Down Expand Up @@ -395,7 +395,7 @@ Analysis::computeDfState(PacRetAnalysis &PRA, BinaryFunction &BF,
if (BC.MIB->isReturn(Inst)) {
ErrorOr<MCPhysReg> MaybeRetReg = BC.MIB->getRegUsedAsRetDest(Inst);
if (MaybeRetReg.getError()) {
Result.Diagnostics.push_back(std::make_shared<GenDiag>(
Result.Diagnostics.push_back(std::make_shared<GenericReport>(
MCInstInBBReference(&BB, I),
"Warning: pac-ret analysis could not analyze this return "
"instruction"));
Expand All @@ -416,9 +416,10 @@ Analysis::computeDfState(PacRetAnalysis &PRA, BinaryFunction &BF,
LLVM_DEBUG(
{ traceRegMask(BC, "Intersection with RetReg", UsedDirtyRegs); });
if (UsedDirtyRegs.any()) {
static const GadgetKind RetKind("non-protected ret found");
// This return instruction needs to be reported
Result.Diagnostics.push_back(std::make_shared<Gadget>(
MCInstInBBReference(&BB, I),
Result.Diagnostics.push_back(std::make_shared<GadgetReport>(
RetKind, MCInstInBBReference(&BB, I),
PRA.getLastClobberingInsts(Inst, BF, UsedDirtyRegs)));
for (MCPhysReg RetRegWithGadget : UsedDirtyRegs.set_bits())
Result.RegistersAffected.insert(RetRegWithGadget);
Expand Down Expand Up @@ -480,54 +481,61 @@ static void printBB(const BinaryContext &BC, const BinaryBasicBlock *BB,

static void reportFoundGadgetInSingleBBSingleOverwInst(
raw_ostream &OS, const BinaryContext &BC, const MCInstReference OverwInst,
const MCInstReference RetInst) {
BinaryBasicBlock *BB = RetInst.getBasicBlock();
const MCInstReference Location) {
BinaryBasicBlock *BB = Location.getBasicBlock();
assert(OverwInst.ParentKind == MCInstReference::BasicBlockParent);
assert(RetInst.ParentKind == MCInstReference::BasicBlockParent);
assert(Location.ParentKind == MCInstReference::BasicBlockParent);
MCInstInBBReference OverwInstBB = OverwInst.U.BBRef;
if (BB == OverwInstBB.BB) {
// overwriting inst and ret instruction are in the same basic block.
assert(OverwInstBB.BBIndex < RetInst.U.BBRef.BBIndex);
assert(OverwInstBB.BBIndex < Location.U.BBRef.BBIndex);
OS << " This happens in the following basic block:\n";
printBB(BC, BB);
}
}

void Gadget::generateReport(raw_ostream &OS, const BinaryContext &BC) const {
GenDiag(RetInst, "non-protected ret found").generateReport(OS, BC);
void Report::printBasicInfo(raw_ostream &OS, const BinaryContext &BC,
StringRef IssueKind) const {
BinaryFunction *BF = Location.getFunction();
BinaryBasicBlock *BB = Location.getBasicBlock();

OS << "\nGS-PAUTH: " << IssueKind;
OS << " in function " << BF->getPrintName();
if (BB)
OS << ", basic block " << BB->getName();
OS << ", at address " << llvm::format("%x", Location.getAddress()) << "\n";
OS << " The instruction is ";
BC.printInstruction(OS, Location, Location.getAddress(), BF);
}

BinaryFunction *BF = RetInst.getFunction();
OS << " The " << OverwritingRetRegInst.size()
<< " instructions that write to the return register after any "
void GadgetReport::generateReport(raw_ostream &OS,
const BinaryContext &BC) const {
printBasicInfo(OS, BC, Kind.getDescription());

BinaryFunction *BF = Location.getFunction();
OS << " The " << OverwritingInstrs.size()
<< " instructions that write to the affected registers after any "
"authentication are:\n";
// Sort by address to ensure output is deterministic.
std::vector<MCInstReference> ORRI = OverwritingRetRegInst;
llvm::sort(ORRI, [](const MCInstReference &A, const MCInstReference &B) {
std::vector<MCInstReference> OI = OverwritingInstrs;
llvm::sort(OI, [](const MCInstReference &A, const MCInstReference &B) {
return A.getAddress() < B.getAddress();
});
for (unsigned I = 0; I < ORRI.size(); ++I) {
MCInstReference InstRef = ORRI[I];
for (unsigned I = 0; I < OI.size(); ++I) {
MCInstReference InstRef = OI[I];
OS << " " << (I + 1) << ". ";
BC.printInstruction(OS, InstRef, InstRef.getAddress(), BF);
};
if (OverwritingRetRegInst.size() == 1) {
const MCInstReference OverwInst = OverwritingRetRegInst[0];
if (OverwritingInstrs.size() == 1) {
const MCInstReference OverwInst = OverwritingInstrs[0];
assert(OverwInst.ParentKind == MCInstReference::BasicBlockParent);
reportFoundGadgetInSingleBBSingleOverwInst(OS, BC, OverwInst, RetInst);
reportFoundGadgetInSingleBBSingleOverwInst(OS, BC, OverwInst, Location);
}
}

void GenDiag::generateReport(raw_ostream &OS, const BinaryContext &BC) const {
BinaryFunction *BF = RetInst.getFunction();
BinaryBasicBlock *BB = RetInst.getBasicBlock();

OS << "\nGS-PACRET: " << Diag.Text;
OS << " in function " << BF->getPrintName();
if (BB)
OS << ", basic block " << BB->getName();
OS << ", at address " << llvm::format("%x", RetInst.getAddress()) << "\n";
OS << " The return instruction is ";
BC.printInstruction(OS, RetInst, RetInst.getAddress(), BF);
void GenericReport::generateReport(raw_ostream &OS,
const BinaryContext &BC) const {
printBasicInfo(OS, BC, Text);
}

Error Analysis::runOnFunctions(BinaryContext &BC) {
Expand All @@ -542,17 +550,16 @@ Error Analysis::runOnFunctions(BinaryContext &BC) {

ParallelUtilities::runOnEachFunctionWithUniqueAllocId(
BC, ParallelUtilities::SchedulingPolicy::SP_INST_LINEAR, WorkFun,
SkipFunc, "NonPacProtectedRetAnalysis");
SkipFunc, "PAuthGadgetScanner");

for (BinaryFunction *BF : BC.getAllBinaryFunctions())
if (AnalysisResults.count(BF) > 0) {
for (const std::shared_ptr<Annotation> &A :
AnalysisResults[BF].Diagnostics)
A->generateReport(outs(), BC);
for (const std::shared_ptr<Report> &R : AnalysisResults[BF].Diagnostics)
R->generateReport(outs(), BC);
}
return Error::success();
}

} // namespace NonPacProtectedRetAnalysis
} // namespace PAuthGadgetScanner
} // namespace bolt
} // namespace llvm
5 changes: 2 additions & 3 deletions bolt/lib/Rewrite/RewriteInstance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "bolt/Passes/BinaryPasses.h"
#include "bolt/Passes/CacheMetrics.h"
#include "bolt/Passes/IdenticalCodeFolding.h"
#include "bolt/Passes/NonPacProtectedRetAnalysis.h"
#include "bolt/Passes/PAuthGadgetScanner.h"
#include "bolt/Passes/ReorderFunctions.h"
#include "bolt/Profile/BoltAddressTranslation.h"
#include "bolt/Profile/DataAggregator.h"
Expand Down Expand Up @@ -3544,8 +3544,7 @@ void RewriteInstance::runBinaryAnalyses() {
opts::GadgetScannersToRun.addValue(GSK::GS_ALL);
for (GSK ScannerToRun : opts::GadgetScannersToRun) {
if (ScannerToRun == GSK::GS_PACRET || ScannerToRun == GSK::GS_ALL)
Manager.registerPass(
std::make_unique<NonPacProtectedRetAnalysis::Analysis>());
Manager.registerPass(std::make_unique<PAuthGadgetScanner::Analysis>());
}

BC->logBOLTErrorsAndQuitOnFatal(Manager.runPasses());
Expand Down
Loading