Skip to content
Merged
151 changes: 120 additions & 31 deletions llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,13 @@ class AArch64AsmPrinter : public AsmPrinter {
void emitPtrauthCheckAuthenticatedValue(Register TestedReg,
Register ScratchReg,
AArch64PACKey::ID Key,
AArch64PAuth::AuthCheckMethod Method,
bool ShouldTrap,
const MCSymbol *OnFailure);

// Check authenticated LR before tail calling.
void emitPtrauthTailCallHardening(const MachineInstr *TC);

// Emit the sequence for AUT or AUTPAC.
void emitPtrauthAuthResign(const MachineInstr *MI);

Expand Down Expand Up @@ -1752,7 +1756,8 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
/// of proceeding to the next instruction (only if ShouldTrap is false).
void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
Register TestedReg, Register ScratchReg, AArch64PACKey::ID Key,
bool ShouldTrap, const MCSymbol *OnFailure) {
AArch64PAuth::AuthCheckMethod Method, bool ShouldTrap,
const MCSymbol *OnFailure) {
// Insert a sequence to check if authentication of TestedReg succeeded,
// such as:
//
Expand All @@ -1778,38 +1783,70 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
// Lsuccess:
// ...
//
// This sequence is expensive, but we need more information to be able to
// do better.
//
// We can't TBZ the poison bit because EnhancedPAC2 XORs the PAC bits
// on failure.
// We can't TST the PAC bits because we don't always know how the address
// space is setup for the target environment (and the bottom PAC bit is
// based on that).
// Either way, we also don't always know whether TBI is enabled or not for
// the specific target environment.
// See the documentation on AuthCheckMethod enumeration constants for
// the specific code sequences that can be used to perform the check.
using AArch64PAuth::AuthCheckMethod;

unsigned XPACOpc = getXPACOpcodeForKey(Key);
if (Method == AuthCheckMethod::None)
return;
if (Method == AuthCheckMethod::DummyLoad) {
EmitToStreamer(MCInstBuilder(AArch64::LDRWui)
.addReg(getWRegFromXReg(ScratchReg))
.addReg(TestedReg)
.addImm(0));
assert(ShouldTrap && !OnFailure && "DummyLoad always traps on error");
return;
}

MCSymbol *SuccessSym = createTempSymbol("auth_success_");
if (Method == AuthCheckMethod::XPAC || Method == AuthCheckMethod::XPACHint) {
// mov Xscratch, Xtested
emitMovXReg(ScratchReg, TestedReg);

// mov Xscratch, Xtested
emitMovXReg(ScratchReg, TestedReg);

// xpac(i|d) Xscratch
EmitToStreamer(MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
if (Method == AuthCheckMethod::XPAC) {
// xpac(i|d) Xscratch
unsigned XPACOpc = getXPACOpcodeForKey(Key);
EmitToStreamer(
MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
} else {
// xpaclri

// Note that this method applies XPAC to TestedReg instead of ScratchReg.
assert(TestedReg == AArch64::LR &&
"XPACHint mode is only compatible with checking the LR register");
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
"XPACHint mode is only compatible with I-keys");
EmitToStreamer(MCInstBuilder(AArch64::XPACLRI));
}

// cmp Xtested, Xscratch
EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
.addReg(AArch64::XZR)
.addReg(TestedReg)
.addReg(ScratchReg)
.addImm(0));
// cmp Xtested, Xscratch
EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
.addReg(AArch64::XZR)
.addReg(TestedReg)
.addReg(ScratchReg)
.addImm(0));

// b.eq Lsuccess
EmitToStreamer(MCInstBuilder(AArch64::Bcc)
.addImm(AArch64CC::EQ)
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
// b.eq Lsuccess
EmitToStreamer(
MCInstBuilder(AArch64::Bcc)
.addImm(AArch64CC::EQ)
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
} else if (Method == AuthCheckMethod::HighBitsNoTBI) {
// eor Xscratch, Xtested, Xtested, lsl #1
EmitToStreamer(MCInstBuilder(AArch64::EORXrs)
.addReg(ScratchReg)
.addReg(TestedReg)
.addReg(TestedReg)
.addImm(1));
// tbz Xscratch, #62, Lsuccess
EmitToStreamer(
MCInstBuilder(AArch64::TBZX)
.addReg(ScratchReg)
.addImm(62)
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
} else {
llvm_unreachable("Unsupported check method");
}

if (ShouldTrap) {
assert(!OnFailure && "Cannot specify OnFailure with ShouldTrap");
Expand All @@ -1823,9 +1860,26 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
// Note that this can introduce an authentication oracle (such as based on
// the high bits of the re-signed value).

// FIXME: Can we simply return the AUT result, already in TestedReg?
// mov Xtested, Xscratch
emitMovXReg(TestedReg, ScratchReg);
// FIXME: The XPAC method can be optimized by applying XPAC to TestedReg
// instead of ScratchReg, thus eliminating one `mov` instruction.
// Both XPAC and XPACHint can be further optimized by not using a
// conditional branch jumping over an unconditional one.

switch (Method) {
case AuthCheckMethod::XPACHint:
// LR is already XPAC-ed at this point.
break;
case AuthCheckMethod::XPAC:
// mov Xtested, Xscratch
emitMovXReg(TestedReg, ScratchReg);
break;
default:
// If Xtested was not XPAC-ed so far, emit XPAC here.
// xpac(i|d) Xtested
unsigned XPACOpc = getXPACOpcodeForKey(Key);
EmitToStreamer(
MCInstBuilder(XPACOpc).addReg(TestedReg).addReg(TestedReg));
}

if (OnFailure) {
// b Lend
Expand All @@ -1840,6 +1894,30 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
OutStreamer->emitLabel(SuccessSym);
}

// With Pointer Authentication, it may be needed to explicitly check the
// authenticated value in LR before performing a tail call.
// Otherwise, the callee may re-sign the invalid return address,
// introducing a signing oracle.
void AArch64AsmPrinter::emitPtrauthTailCallHardening(const MachineInstr *TC) {
if (!AArch64FI->shouldSignReturnAddress(*MF))
return;

auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
return;

const AArch64RegisterInfo *TRI = STI->getRegisterInfo();
Register ScratchReg =
TC->readsRegister(AArch64::X16, TRI) ? AArch64::X17 : AArch64::X16;
assert(!TC->readsRegister(ScratchReg, TRI) &&
"Neither x16 nor x17 is available as a scratch register");
AArch64PACKey::ID Key =
AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
emitPtrauthCheckAuthenticatedValue(
AArch64::LR, ScratchReg, Key, LRCheckMethod,
/*ShouldTrap=*/true, /*OnFailure=*/nullptr);
}

void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
const bool IsAUTPAC = MI->getOpcode() == AArch64::AUTPAC;

Expand All @@ -1851,7 +1929,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
// ; sign x16 (if AUTPAC)
// Lend: ; if not trapping on failure
//
// with the checking sequence chosen depending on whether we should check
// with the checking sequence chosen depending on whether/how we should check
// the pointer and whether we should trap on failure.

// By default, auth/resign sequences check for auth failures.
Expand Down Expand Up @@ -1911,6 +1989,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
EndSym = createTempSymbol("resign_end_");

emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AUTKey,
AArch64PAuth::AuthCheckMethod::XPAC,
ShouldTrap, EndSym);
}

Expand Down Expand Up @@ -2195,6 +2274,7 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
: AArch64PACKey::DA);

emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AuthKey,
AArch64PAuth::AuthCheckMethod::XPAC,
/*ShouldTrap=*/true,
/*OnFailure=*/nullptr);
}
Expand Down Expand Up @@ -2327,6 +2407,7 @@ void AArch64AsmPrinter::LowerLOADgotAUTH(const MachineInstr &MI) {
(AuthOpcode == AArch64::AUTIA ? AArch64PACKey::IA : AArch64PACKey::DA);

emitPtrauthCheckAuthenticatedValue(AuthResultReg, AArch64::X17, AuthKey,
AArch64PAuth::AuthCheckMethod::XPAC,
/*ShouldTrap=*/true,
/*OnFailure=*/nullptr);

Expand Down Expand Up @@ -2396,6 +2477,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
// Do any manual lowerings.
switch (MI->getOpcode()) {
default:
assert(!AArch64InstrInfo::isTailCallReturnInst(*MI) &&
"Unhandled tail call instruction");
break;
case AArch64::HINT: {
// CurrentPatchableFunctionEntrySym can be CurrentFnBegin only for
Expand Down Expand Up @@ -2539,6 +2622,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
? AArch64::X17
: AArch64::X16;

emitPtrauthTailCallHardening(MI);

unsigned DiscReg = AddrDisc;
if (Disc) {
if (AddrDisc != AArch64::NoRegister) {
Expand Down Expand Up @@ -2569,13 +2654,17 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
case AArch64::TCRETURNrix17:
case AArch64::TCRETURNrinotx16:
case AArch64::TCRETURNriALL: {
emitPtrauthTailCallHardening(MI);

MCInst TmpInst;
TmpInst.setOpcode(AArch64::BR);
TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
EmitToStreamer(*OutStreamer, TmpInst);
return;
}
case AArch64::TCRETURNdi: {
emitPtrauthTailCallHardening(MI);

MCOperand Dest;
MCInstLowering.lowerOperand(MI->getOperand(0), Dest);
MCInst TmpInst;
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,19 @@ unsigned AArch64InstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
unsigned NumBytes = 0;
const MCInstrDesc &Desc = MI.getDesc();

if (!MI.isBundle() && isTailCallReturnInst(MI)) {
NumBytes = Desc.getSize() ? Desc.getSize() : 4;

const auto *MFI = MF->getInfo<AArch64FunctionInfo>();
if (!MFI->shouldSignReturnAddress(MF))
return NumBytes;

const auto &STI = MF->getSubtarget<AArch64Subtarget>();
auto Method = STI.getAuthenticatedLRCheckMethod(*MF);
NumBytes += AArch64PAuth::getCheckerSizeInBytes(Method);
return NumBytes;
}

// Size should be preferably set in
// llvm/lib/Target/AArch64/AArch64InstrInfo.td (default case).
// Specific cases handle instructions of variable sizes
Expand Down
18 changes: 12 additions & 6 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1964,30 +1964,36 @@ let Predicates = [HasPAuth] in {
}

// Size 16: 4 fixed + 8 variable, to compute discriminator.
// The size returned by getInstSizeInBytes() is incremented according
// to the variant of LR check.
// As the check requires either x16 or x17 as a scratch register and
// authenticated tail call instructions have two register operands,
// make sure at least one register is usable as a scratch one - for that
// purpose, use tcGPRnotx16x17 register class for one of the operands.
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Size = 16,
Uses = [SP] in {
def AUTH_TCRETURN
: Pseudo<(outs), (ins tcGPR64:$dst, i32imm:$FPDiff, i32imm:$Key,
: Pseudo<(outs), (ins tcGPRnotx16x17:$dst, i32imm:$FPDiff, i32imm:$Key,
i64imm:$Disc, tcGPR64:$AddrDisc),
[]>, Sched<[WriteBrReg]>;
def AUTH_TCRETURN_BTI
: Pseudo<(outs), (ins tcGPRx16x17:$dst, i32imm:$FPDiff, i32imm:$Key,
i64imm:$Disc, tcGPR64:$AddrDisc),
i64imm:$Disc, tcGPRnotx16x17:$AddrDisc),
[]>, Sched<[WriteBrReg]>;
}

let Predicates = [TailCallAny] in
def : Pat<(AArch64authtcret tcGPR64:$dst, (i32 timm:$FPDiff), (i32 timm:$Key),
def : Pat<(AArch64authtcret tcGPRnotx16x17:$dst, (i32 timm:$FPDiff), (i32 timm:$Key),
(i64 timm:$Disc), tcGPR64:$AddrDisc),
(AUTH_TCRETURN tcGPR64:$dst, imm:$FPDiff, imm:$Key, imm:$Disc,
(AUTH_TCRETURN tcGPRnotx16x17:$dst, imm:$FPDiff, imm:$Key, imm:$Disc,
tcGPR64:$AddrDisc)>;

let Predicates = [TailCallX16X17] in
def : Pat<(AArch64authtcret tcGPRx16x17:$dst, (i32 timm:$FPDiff),
(i32 timm:$Key), (i64 timm:$Disc),
tcGPR64:$AddrDisc),
tcGPRnotx16x17:$AddrDisc),
(AUTH_TCRETURN_BTI tcGPRx16x17:$dst, imm:$FPDiff, imm:$Key,
imm:$Disc, tcGPR64:$AddrDisc)>;
imm:$Disc, tcGPRnotx16x17:$AddrDisc)>;
}

// v9.5-A pointer authentication extensions
Expand Down
Loading
Loading