Skip to content

Commit 1f5c0bf

Browse files
committed
[AArch64][PAC] Move emission of LR checks in tail calls to AsmPrinter
Move the emission of the checks performed on the authenticated LR value during tail calls to AArch64AsmPrinter class, so that different checker sequences can be reused by pseudo instructions expanded there. This adds one more option to AuthCheckMethod enumeration, the generic XPAC variant which is not restricted to checking the LR register.
1 parent 0a68171 commit 1f5c0bf

File tree

9 files changed

+192
-303
lines changed

9 files changed

+192
-303
lines changed

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 112 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ class AArch64AsmPrinter : public AsmPrinter {
153153
void emitPtrauthCheckAuthenticatedValue(Register TestedReg,
154154
Register ScratchReg,
155155
AArch64PACKey::ID Key,
156+
AArch64PAuth::AuthCheckMethod Method,
156157
bool ShouldTrap,
157158
const MCSymbol *OnFailure);
158159

@@ -1752,7 +1753,8 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
17521753
/// of proceeding to the next instruction (only if ShouldTrap is false).
17531754
void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
17541755
Register TestedReg, Register ScratchReg, AArch64PACKey::ID Key,
1755-
bool ShouldTrap, const MCSymbol *OnFailure) {
1756+
AArch64PAuth::AuthCheckMethod Method, bool ShouldTrap,
1757+
const MCSymbol *OnFailure) {
17561758
// Insert a sequence to check if authentication of TestedReg succeeded,
17571759
// such as:
17581760
//
@@ -1778,38 +1780,70 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
17781780
// Lsuccess:
17791781
// ...
17801782
//
1781-
// This sequence is expensive, but we need more information to be able to
1782-
// do better.
1783-
//
1784-
// We can't TBZ the poison bit because EnhancedPAC2 XORs the PAC bits
1785-
// on failure.
1786-
// We can't TST the PAC bits because we don't always know how the address
1787-
// space is setup for the target environment (and the bottom PAC bit is
1788-
// based on that).
1789-
// Either way, we also don't always know whether TBI is enabled or not for
1790-
// the specific target environment.
1783+
// See the documentation on AuthCheckMethod enumeration constants for
1784+
// the specific code sequences that can be used to perform the check.
1785+
using AArch64PAuth::AuthCheckMethod;
17911786

1792-
unsigned XPACOpc = getXPACOpcodeForKey(Key);
1787+
if (Method == AuthCheckMethod::None)
1788+
return;
1789+
if (Method == AuthCheckMethod::DummyLoad) {
1790+
EmitToStreamer(MCInstBuilder(AArch64::LDRWui)
1791+
.addReg(getWRegFromXReg(ScratchReg))
1792+
.addReg(TestedReg)
1793+
.addImm(0));
1794+
assert(ShouldTrap && !OnFailure && "DummyLoad always traps on error");
1795+
return;
1796+
}
17931797

17941798
MCSymbol *SuccessSym = createTempSymbol("auth_success_");
1799+
if (Method == AuthCheckMethod::XPAC || Method == AuthCheckMethod::XPACHint) {
1800+
// mov Xscratch, Xtested
1801+
emitMovXReg(ScratchReg, TestedReg);
17951802

1796-
// mov Xscratch, Xtested
1797-
emitMovXReg(ScratchReg, TestedReg);
1798-
1799-
// xpac(i|d) Xscratch
1800-
EmitToStreamer(MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
1803+
if (Method == AuthCheckMethod::XPAC) {
1804+
// xpac(i|d) Xscratch
1805+
unsigned XPACOpc = getXPACOpcodeForKey(Key);
1806+
EmitToStreamer(
1807+
MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
1808+
} else {
1809+
// xpaclri
1810+
1811+
// Note that this method applies XPAC to TestedReg instead of ScratchReg.
1812+
assert(TestedReg == AArch64::LR &&
1813+
"XPACHint mode is only compatible with checking the LR register");
1814+
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
1815+
"XPACHint mode is only compatible with I-keys");
1816+
EmitToStreamer(MCInstBuilder(AArch64::XPACLRI));
1817+
}
18011818

1802-
// cmp Xtested, Xscratch
1803-
EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
1804-
.addReg(AArch64::XZR)
1805-
.addReg(TestedReg)
1806-
.addReg(ScratchReg)
1807-
.addImm(0));
1819+
// cmp Xtested, Xscratch
1820+
EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
1821+
.addReg(AArch64::XZR)
1822+
.addReg(TestedReg)
1823+
.addReg(ScratchReg)
1824+
.addImm(0));
18081825

1809-
// b.eq Lsuccess
1810-
EmitToStreamer(MCInstBuilder(AArch64::Bcc)
1811-
.addImm(AArch64CC::EQ)
1812-
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
1826+
// b.eq Lsuccess
1827+
EmitToStreamer(
1828+
MCInstBuilder(AArch64::Bcc)
1829+
.addImm(AArch64CC::EQ)
1830+
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
1831+
} else if (Method == AuthCheckMethod::HighBitsNoTBI) {
1832+
// eor Xscratch, Xtested, Xtested, lsl #1
1833+
EmitToStreamer(MCInstBuilder(AArch64::EORXrs)
1834+
.addReg(ScratchReg)
1835+
.addReg(TestedReg)
1836+
.addReg(TestedReg)
1837+
.addImm(1));
1838+
// tbz Xscratch, #62, Lsuccess
1839+
EmitToStreamer(
1840+
MCInstBuilder(AArch64::TBZX)
1841+
.addReg(ScratchReg)
1842+
.addImm(62)
1843+
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
1844+
} else {
1845+
llvm_unreachable("Unsupported check method");
1846+
}
18131847

18141848
if (ShouldTrap) {
18151849
assert(!OnFailure && "Cannot specify OnFailure with ShouldTrap");
@@ -1823,9 +1857,26 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
18231857
// Note that this can introduce an authentication oracle (such as based on
18241858
// the high bits of the re-signed value).
18251859

1826-
// FIXME: Can we simply return the AUT result, already in TestedReg?
1827-
// mov Xtested, Xscratch
1828-
emitMovXReg(TestedReg, ScratchReg);
1860+
// FIXME: The XPAC method can be optimized by applying XPAC to TestedReg
1861+
// instead of ScratchReg, thus eliminating one `mov` instruction.
1862+
// Both XPAC and XPACHint can be further optimized by not using a
1863+
// conditional branch jumping over an unconditional one.
1864+
1865+
switch (Method) {
1866+
case AuthCheckMethod::XPACHint:
1867+
// LR is already XPAC-ed at this point.
1868+
break;
1869+
case AuthCheckMethod::XPAC:
1870+
// mov Xtested, Xscratch
1871+
emitMovXReg(TestedReg, ScratchReg);
1872+
break;
1873+
default:
1874+
// If Xtested was not XPAC-ed so far, emit XPAC here.
1875+
// xpac(i|d) Xtested
1876+
unsigned XPACOpc = getXPACOpcodeForKey(Key);
1877+
EmitToStreamer(
1878+
MCInstBuilder(XPACOpc).addReg(TestedReg).addReg(TestedReg));
1879+
}
18291880

18301881
if (OnFailure) {
18311882
// b Lend
@@ -1851,7 +1902,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
18511902
// ; sign x16 (if AUTPAC)
18521903
// Lend: ; if not trapping on failure
18531904
//
1854-
// with the checking sequence chosen depending on whether we should check
1905+
// with the checking sequence chosen depending on whether/how we should check
18551906
// the pointer and whether we should trap on failure.
18561907

18571908
// By default, auth/resign sequences check for auth failures.
@@ -1911,6 +1962,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
19111962
EndSym = createTempSymbol("resign_end_");
19121963

19131964
emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AUTKey,
1965+
AArch64PAuth::AuthCheckMethod::XPAC,
19141966
ShouldTrap, EndSym);
19151967
}
19161968

@@ -2391,11 +2443,34 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
23912443
OutStreamer->emitLabel(LOHLabel);
23922444
}
23932445

2446+
// With Pointer Authentication, it may be needed to explicitly check the
2447+
// authenticated value in LR when performing a tail call.
2448+
// Otherwise, the callee may re-sign the invalid return address,
2449+
// introducing a signing oracle.
2450+
auto CheckLRInTailCall = [this](Register CallDestinationReg) {
2451+
if (!AArch64FI->shouldSignReturnAddress(*MF))
2452+
return;
2453+
2454+
auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
2455+
if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
2456+
return;
2457+
2458+
Register ScratchReg =
2459+
CallDestinationReg == AArch64::X16 ? AArch64::X17 : AArch64::X16;
2460+
AArch64PACKey::ID Key =
2461+
AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
2462+
emitPtrauthCheckAuthenticatedValue(
2463+
AArch64::LR, ScratchReg, Key, LRCheckMethod,
2464+
/*ShouldTrap=*/true, /*OnFailure=*/nullptr);
2465+
};
2466+
23942467
AArch64TargetStreamer *TS =
23952468
static_cast<AArch64TargetStreamer *>(OutStreamer->getTargetStreamer());
23962469
// Do any manual lowerings.
23972470
switch (MI->getOpcode()) {
23982471
default:
2472+
assert(!AArch64InstrInfo::isTailCallReturnInst(*MI) &&
2473+
"Unhandled tail call instruction");
23992474
break;
24002475
case AArch64::HINT: {
24012476
// CurrentPatchableFunctionEntrySym can be CurrentFnBegin only for
@@ -2539,6 +2614,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
25392614
? AArch64::X17
25402615
: AArch64::X16;
25412616

2617+
CheckLRInTailCall(MI->getOperand(0).getReg());
2618+
25422619
unsigned DiscReg = AddrDisc;
25432620
if (Disc) {
25442621
if (AddrDisc != AArch64::NoRegister) {
@@ -2569,13 +2646,17 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
25692646
case AArch64::TCRETURNrix17:
25702647
case AArch64::TCRETURNrinotx16:
25712648
case AArch64::TCRETURNriALL: {
2649+
CheckLRInTailCall(MI->getOperand(0).getReg());
2650+
25722651
MCInst TmpInst;
25732652
TmpInst.setOpcode(AArch64::BR);
25742653
TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
25752654
EmitToStreamer(*OutStreamer, TmpInst);
25762655
return;
25772656
}
25782657
case AArch64::TCRETURNdi: {
2658+
CheckLRInTailCall(AArch64::NoRegister);
2659+
25792660
MCOperand Dest;
25802661
MCInstLowering.lowerOperand(MI->getOperand(0), Dest);
25812662
MCInst TmpInst;

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,19 @@ unsigned AArch64InstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
107107
unsigned NumBytes = 0;
108108
const MCInstrDesc &Desc = MI.getDesc();
109109

110+
if (!MI.isBundle() && isTailCallReturnInst(MI)) {
111+
NumBytes = Desc.getSize() ? Desc.getSize() : 4;
112+
113+
const auto *MFI = MF->getInfo<AArch64FunctionInfo>();
114+
if (!MFI->shouldSignReturnAddress(MF))
115+
return NumBytes;
116+
117+
auto &STI = MF->getSubtarget<AArch64Subtarget>();
118+
auto Method = STI.getAuthenticatedLRCheckMethod(*MF);
119+
NumBytes += AArch64PAuth::getCheckerSizeInBytes(Method);
120+
return NumBytes;
121+
}
122+
110123
// Size should be preferably set in
111124
// llvm/lib/Target/AArch64/AArch64InstrInfo.td (default case).
112125
// Specific cases handle instructions of variable sizes

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,6 +1964,8 @@ let Predicates = [HasPAuth] in {
19641964
}
19651965

19661966
// Size 16: 4 fixed + 8 variable, to compute discriminator.
1967+
// The size returned by getInstSizeInBytes() is incremented according
1968+
// to the variant of LR check.
19671969
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Size = 16,
19681970
Uses = [SP] in {
19691971
def AUTH_TCRETURN

0 commit comments

Comments
 (0)