Skip to content

Commit 42d2ae1

Browse files
[InstCombine] Combine ptrauth constant callee into bundle. (#94706)
Try to optimize a call to a ptrauth constant, into its ptrauth bundle: call(ptrauth(f)), ["ptrauth"()] -> call f as long as the key/discriminator are the same in constant and bundle.
1 parent 1a940bf commit 42d2ae1

File tree

3 files changed

+134
-0
lines changed

3 files changed

+134
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4050,6 +4050,34 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) {
40504050
return nullptr;
40514051
}
40524052

4053+
Instruction *InstCombinerImpl::foldPtrAuthConstantCallee(CallBase &Call) {
4054+
auto *CPA = dyn_cast<ConstantPtrAuth>(Call.getCalledOperand());
4055+
if (!CPA)
4056+
return nullptr;
4057+
4058+
auto *CalleeF = dyn_cast<Function>(CPA->getPointer());
4059+
// If the ptrauth constant isn't based on a function pointer, bail out.
4060+
if (!CalleeF)
4061+
return nullptr;
4062+
4063+
// Inspect the call ptrauth bundle to check it matches the ptrauth constant.
4064+
auto PAB = Call.getOperandBundle(LLVMContext::OB_ptrauth);
4065+
if (!PAB)
4066+
return nullptr;
4067+
4068+
auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
4069+
Value *Discriminator = PAB->Inputs[1];
4070+
4071+
// If the bundle doesn't match, this is probably going to fail to auth.
4072+
if (!CPA->isKnownCompatibleWith(Key, Discriminator, DL))
4073+
return nullptr;
4074+
4075+
// If the bundle matches the constant, proceed in making this a direct call.
4076+
auto *NewCall = CallBase::removeOperandBundle(&Call, LLVMContext::OB_ptrauth);
4077+
NewCall->setCalledOperand(CalleeF);
4078+
return NewCall;
4079+
}
4080+
40534081
bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call,
40544082
const TargetLibraryInfo *TLI) {
40554083
// Note: We only handle cases which can't be driven from generic attributes
@@ -4210,6 +4238,10 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
42104238
if (IntrinsicInst *II = findInitTrampoline(Callee))
42114239
return transformCallThroughTrampoline(Call, *II);
42124240

4241+
// Combine calls to ptrauth constants.
4242+
if (Instruction *NewCall = foldPtrAuthConstantCallee(Call))
4243+
return NewCall;
4244+
42134245
if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) {
42144246
InlineAsm *IA = cast<InlineAsm>(Callee);
42154247
if (!IA->canThrow()) {

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
283283
Instruction *transformCallThroughTrampoline(CallBase &Call,
284284
IntrinsicInst &Tramp);
285285

286+
/// Try to optimize a call to a ptrauth constant, into its ptrauth bundle:
287+
/// call(ptrauth(f)), ["ptrauth"()] -> call f
288+
/// as long as the key/discriminator are the same in constant and bundle.
289+
Instruction *foldPtrAuthConstantCallee(CallBase &Call);
290+
286291
// Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
287292
// Otherwise, return std::nullopt
288293
// Currently it matches:
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
4+
declare i64 @f(i32)
5+
declare ptr @f2(i32)
6+
7+
define i32 @test_ptrauth_call(i32 %a0) {
8+
; CHECK-LABEL: @test_ptrauth_call(
9+
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
10+
; CHECK-NEXT: ret i32 [[V0]]
11+
;
12+
%v0 = call i32 ptrauth(ptr @f, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ]
13+
ret i32 %v0
14+
}
15+
16+
define i32 @test_ptrauth_call_disc(i32 %a0) {
17+
; CHECK-LABEL: @test_ptrauth_call_disc(
18+
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
19+
; CHECK-NEXT: ret i32 [[V0]]
20+
;
21+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 5678) ]
22+
ret i32 %v0
23+
}
24+
25+
@f_addr_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)
26+
27+
define i32 @test_ptrauth_call_addr_disc(i32 %a0) {
28+
; CHECK-LABEL: @test_ptrauth_call_addr_disc(
29+
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
30+
; CHECK-NEXT: ret i32 [[V0]]
31+
;
32+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 ptrtoint (ptr @f_addr_disc.ref to i64)) ]
33+
ret i32 %v0
34+
}
35+
36+
@f_both_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)
37+
38+
define i32 @test_ptrauth_call_blend(i32 %a0) {
39+
; CHECK-LABEL: @test_ptrauth_call_blend(
40+
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
41+
; CHECK-NEXT: ret i32 [[V0]]
42+
;
43+
%v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 1234)
44+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
45+
ret i32 %v0
46+
}
47+
48+
define i64 @test_ptrauth_call_cast(i32 %a0) {
49+
; CHECK-LABEL: @test_ptrauth_call_cast(
50+
; CHECK-NEXT: [[V0:%.*]] = call i64 @f2(i32 [[A0:%.*]])
51+
; CHECK-NEXT: ret i64 [[V0]]
52+
;
53+
%v0 = call i64 ptrauth(ptr @f2, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ]
54+
ret i64 %v0
55+
}
56+
57+
define i32 @test_ptrauth_call_mismatch_key(i32 %a0) {
58+
; CHECK-LABEL: @test_ptrauth_call_mismatch_key(
59+
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 0, i64 5678) ]
60+
; CHECK-NEXT: ret i32 [[V0]]
61+
;
62+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 0, i64 5678) ]
63+
ret i32 %v0
64+
}
65+
66+
define i32 @test_ptrauth_call_mismatch_disc(i32 %a0) {
67+
; CHECK-LABEL: @test_ptrauth_call_mismatch_disc(
68+
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 0) ]
69+
; CHECK-NEXT: ret i32 [[V0]]
70+
;
71+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 0) ]
72+
ret i32 %v0
73+
}
74+
75+
define i32 @test_ptrauth_call_mismatch_blend(i32 %a0) {
76+
; CHECK-LABEL: @test_ptrauth_call_mismatch_blend(
77+
; CHECK-NEXT: [[V:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0)
78+
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 [[V]]) ]
79+
; CHECK-NEXT: ret i32 [[V0]]
80+
;
81+
%v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0)
82+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
83+
ret i32 %v0
84+
}
85+
86+
define i32 @test_ptrauth_call_mismatch_blend_addr(i32 %a0) {
87+
; CHECK-LABEL: @test_ptrauth_call_mismatch_blend_addr(
88+
; CHECK-NEXT: [[V:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_addr_disc.ref to i64), i64 1234)
89+
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 [[V]]) ]
90+
; CHECK-NEXT: ret i32 [[V0]]
91+
;
92+
%v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_addr_disc.ref to i64), i64 1234)
93+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
94+
ret i32 %v0
95+
}
96+
97+
declare i64 @llvm.ptrauth.blend(i64, i64)

0 commit comments

Comments
 (0)