Skip to content

Commit 2c4e9ff

Browse files
committed
Check both register operands of AUTH_TCRETURN*
1 parent 1f5c0bf commit 2c4e9ff

File tree

4 files changed

+42
-28
lines changed

4 files changed

+42
-28
lines changed

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ class AArch64AsmPrinter : public AsmPrinter {
157157
bool ShouldTrap,
158158
const MCSymbol *OnFailure);
159159

160+
// Check authenticated LR before tail calling.
161+
void emitPtrauthTailCallHardening(const MachineInstr *TC);
162+
160163
// Emit the sequence for AUT or AUTPAC.
161164
void emitPtrauthAuthResign(const MachineInstr *MI);
162165

@@ -1891,6 +1894,30 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
18911894
OutStreamer->emitLabel(SuccessSym);
18921895
}
18931896

1897+
// With Pointer Authentication, it may be needed to explicitly check the
1898+
// authenticated value in LR before performing a tail call.
1899+
// Otherwise, the callee may re-sign the invalid return address,
1900+
// introducing a signing oracle.
1901+
void AArch64AsmPrinter::emitPtrauthTailCallHardening(const MachineInstr *TC) {
1902+
if (!AArch64FI->shouldSignReturnAddress(*MF))
1903+
return;
1904+
1905+
auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
1906+
if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
1907+
return;
1908+
1909+
const AArch64RegisterInfo *TRI = STI->getRegisterInfo();
1910+
Register ScratchReg =
1911+
TC->readsRegister(AArch64::X16, TRI) ? AArch64::X17 : AArch64::X16;
1912+
assert(!TC->readsRegister(ScratchReg, TRI) &&
1913+
"Neither x16 nor x17 is available as a scratch register");
1914+
AArch64PACKey::ID Key =
1915+
AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
1916+
emitPtrauthCheckAuthenticatedValue(
1917+
AArch64::LR, ScratchReg, Key, LRCheckMethod,
1918+
/*ShouldTrap=*/true, /*OnFailure=*/nullptr);
1919+
}
1920+
18941921
void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
18951922
const bool IsAUTPAC = MI->getOpcode() == AArch64::AUTPAC;
18961923

@@ -2443,27 +2470,6 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
24432470
OutStreamer->emitLabel(LOHLabel);
24442471
}
24452472

2446-
// With Pointer Authentication, it may be needed to explicitly check the
2447-
// authenticated value in LR when performing a tail call.
2448-
// Otherwise, the callee may re-sign the invalid return address,
2449-
// introducing a signing oracle.
2450-
auto CheckLRInTailCall = [this](Register CallDestinationReg) {
2451-
if (!AArch64FI->shouldSignReturnAddress(*MF))
2452-
return;
2453-
2454-
auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
2455-
if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
2456-
return;
2457-
2458-
Register ScratchReg =
2459-
CallDestinationReg == AArch64::X16 ? AArch64::X17 : AArch64::X16;
2460-
AArch64PACKey::ID Key =
2461-
AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
2462-
emitPtrauthCheckAuthenticatedValue(
2463-
AArch64::LR, ScratchReg, Key, LRCheckMethod,
2464-
/*ShouldTrap=*/true, /*OnFailure=*/nullptr);
2465-
};
2466-
24672473
AArch64TargetStreamer *TS =
24682474
static_cast<AArch64TargetStreamer *>(OutStreamer->getTargetStreamer());
24692475
// Do any manual lowerings.
@@ -2614,7 +2620,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
26142620
? AArch64::X17
26152621
: AArch64::X16;
26162622

2617-
CheckLRInTailCall(MI->getOperand(0).getReg());
2623+
emitPtrauthTailCallHardening(MI);
26182624

26192625
unsigned DiscReg = AddrDisc;
26202626
if (Disc) {
@@ -2646,7 +2652,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
26462652
case AArch64::TCRETURNrix17:
26472653
case AArch64::TCRETURNrinotx16:
26482654
case AArch64::TCRETURNriALL: {
2649-
CheckLRInTailCall(MI->getOperand(0).getReg());
2655+
emitPtrauthTailCallHardening(MI);
26502656

26512657
MCInst TmpInst;
26522658
TmpInst.setOpcode(AArch64::BR);
@@ -2655,7 +2661,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
26552661
return;
26562662
}
26572663
case AArch64::TCRETURNdi: {
2658-
CheckLRInTailCall(AArch64::NoRegister);
2664+
emitPtrauthTailCallHardening(MI);
26592665

26602666
MCOperand Dest;
26612667
MCInstLowering.lowerOperand(MI->getOperand(0), Dest);

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1966,15 +1966,19 @@ let Predicates = [HasPAuth] in {
19661966
// Size 16: 4 fixed + 8 variable, to compute discriminator.
19671967
// The size returned by getInstSizeInBytes() is incremented according
19681968
// to the variant of LR check.
1969+
// As the check requires either x16 or x17 as a scratch register and
1970+
// authenticated tail call instructions have two register operands,
1971+
// make sure at least one register is usable as a scratch one - for that
1972+
// purpose, use tcGPRnotx16x17 register class for the second operand.
19691973
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Size = 16,
19701974
Uses = [SP] in {
19711975
def AUTH_TCRETURN
19721976
: Pseudo<(outs), (ins tcGPR64:$dst, i32imm:$FPDiff, i32imm:$Key,
1973-
i64imm:$Disc, tcGPR64:$AddrDisc),
1977+
i64imm:$Disc, tcGPRnotx16x17:$AddrDisc),
19741978
[]>, Sched<[WriteBrReg]>;
19751979
def AUTH_TCRETURN_BTI
19761980
: Pseudo<(outs), (ins tcGPRx16x17:$dst, i32imm:$FPDiff, i32imm:$Key,
1977-
i64imm:$Disc, tcGPR64:$AddrDisc),
1981+
i64imm:$Disc, tcGPRnotx16x17:$AddrDisc),
19781982
[]>, Sched<[WriteBrReg]>;
19791983
}
19801984

llvm/lib/Target/AArch64/AArch64RegisterInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,10 @@ def tcGPR64 : RegisterClass<"AArch64", [i64], 64, (sub GPR64common, X19, X20, X2
249249
def tcGPRx17 : RegisterClass<"AArch64", [i64], 64, (add X17)>;
250250
def tcGPRx16x17 : RegisterClass<"AArch64", [i64], 64, (add X16, X17)>;
251251
def tcGPRnotx16 : RegisterClass<"AArch64", [i64], 64, (sub tcGPR64, X16)>;
252+
// LR checking code expects either x16 or x17 to be available as a scratch
253+
// register - for that reason restrict one of two register operands of
254+
// AUTH_TCRETURN* pseudos.
255+
def tcGPRnotx16x17 : RegisterClass<"AArch64", [i64], 64, (sub tcGPR64, X16, X17)>;
252256

253257
// Register set that excludes registers that are reserved for procedure calls.
254258
// This is used for pseudo-instructions that are actually implemented using a

llvm/test/CodeGen/AArch64/ptrauth-call.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ define void @test_tailcall_omit_mov_x16_x16(ptr %objptr) #0 {
173173
; CHECK: mov x17, x0
174174
; CHECK: movk x17, #6503, lsl #48
175175
; CHECK: autda x16, x17
176-
; CHECK: ldr x1, [x16]
176+
; CHECK: ldr x2, [x16]
177177
; CHECK: movk x16, #54167, lsl #48
178-
; CHECK: braa x1, x16
178+
; CHECK: braa x2, x16
179179
%vtable.signed = load ptr, ptr %objptr, align 8
180180
%objptr.int = ptrtoint ptr %objptr to i64
181181
%vtable.discr = tail call i64 @llvm.ptrauth.blend(i64 %objptr.int, i64 6503)

0 commit comments

Comments
 (0)