diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 75da24e5cf18f..73293bb5f4a0e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -4050,6 +4050,83 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) { return nullptr; } +Instruction *InstCombinerImpl::foldPtrAuthIntrinsicCallee(CallBase &Call) { + const Value *Callee = Call.getCalledOperand(); + const auto *IPC = dyn_cast(Callee); + if (!IPC || !IPC->isNoopCast(DL)) + return nullptr; + + const auto *II = dyn_cast(IPC->getOperand(0)); + if (!II) + return nullptr; + + Intrinsic::ID IIID = II->getIntrinsicID(); + if (IIID != Intrinsic::ptrauth_resign && IIID != Intrinsic::ptrauth_sign) + return nullptr; + + // Isolate the ptrauth bundle from the others. + std::optional PtrAuthBundleOrNone; + SmallVector NewBundles; + for (unsigned BI = 0, BE = Call.getNumOperandBundles(); BI != BE; ++BI) { + OperandBundleUse Bundle = Call.getOperandBundleAt(BI); + if (Bundle.getTagID() == LLVMContext::OB_ptrauth) + PtrAuthBundleOrNone = Bundle; + else + NewBundles.emplace_back(Bundle); + } + + if (!PtrAuthBundleOrNone) + return nullptr; + + Value *NewCallee = nullptr; + switch (IIID) { + // call(ptrauth.resign(p)), ["ptrauth"()] -> call p, ["ptrauth"()] + // assuming the call bundle and the sign operands match. + case Intrinsic::ptrauth_resign: { + // Resign result key should match bundle. + if (II->getOperand(3) != PtrAuthBundleOrNone->Inputs[0]) + return nullptr; + // Resign result discriminator should match bundle. + if (II->getOperand(4) != PtrAuthBundleOrNone->Inputs[1]) + return nullptr; + + // Resign input (auth) key should also match: we can't change the key on + // the new call we're generating, because we don't know what keys are valid. + if (II->getOperand(1) != PtrAuthBundleOrNone->Inputs[0]) + return nullptr; + + Value *NewBundleOps[] = {II->getOperand(1), II->getOperand(2)}; + NewBundles.emplace_back("ptrauth", NewBundleOps); + NewCallee = II->getOperand(0); + break; + } + + // call(ptrauth.sign(p)), ["ptrauth"()] -> call p + // assuming the call bundle and the sign operands match. + // Non-ptrauth indirect calls are undesirable, but so is ptrauth.sign. + case Intrinsic::ptrauth_sign: { + // Sign key should match bundle. + if (II->getOperand(1) != PtrAuthBundleOrNone->Inputs[0]) + return nullptr; + // Sign discriminator should match bundle. + if (II->getOperand(2) != PtrAuthBundleOrNone->Inputs[1]) + return nullptr; + NewCallee = II->getOperand(0); + break; + } + default: + llvm_unreachable("unexpected intrinsic ID"); + } + + if (!NewCallee) + return nullptr; + + NewCallee = Builder.CreateBitOrPointerCast(NewCallee, Callee->getType()); + CallBase *NewCall = CallBase::Create(&Call, NewBundles); + NewCall->setCalledOperand(NewCallee); + return NewCall; +} + Instruction *InstCombinerImpl::foldPtrAuthConstantCallee(CallBase &Call) { auto *CPA = dyn_cast(Call.getCalledOperand()); if (!CPA) @@ -4238,6 +4315,10 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { if (IntrinsicInst *II = findInitTrampoline(Callee)) return transformCallThroughTrampoline(Call, *II); + // Combine calls involving pointer authentication intrinsics. + if (Instruction *NewCall = foldPtrAuthIntrinsicCallee(Call)) + return NewCall; + // Combine calls to ptrauth constants. if (Instruction *NewCall = foldPtrAuthConstantCallee(Call)) return NewCall; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 0c2ef3ebf88dc..9d7c025ccff86 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -283,6 +283,14 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final Instruction *transformCallThroughTrampoline(CallBase &Call, IntrinsicInst &Tramp); + /// Try to optimize a call to the result of a ptrauth intrinsic, potentially + /// into the ptrauth call bundle: + /// - call(ptrauth.resign(p)), ["ptrauth"()] -> call p, ["ptrauth"()] + /// - call(ptrauth.sign(p)), ["ptrauth"()] -> call p + /// as long as the key/discriminator are the same in sign and auth-bundle, + /// and we don't change the key in the bundle (to a potentially-invalid key.) + Instruction *foldPtrAuthIntrinsicCallee(CallBase &Call); + /// Try to optimize a call to a ptrauth constant, into its ptrauth bundle: /// call(ptrauth(f)), ["ptrauth"()] -> call f /// as long as the key/discriminator are the same in constant and bundle. diff --git a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics-call.ll b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics-call.ll new file mode 100644 index 0000000000000..5e597c9155f40 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics-call.ll @@ -0,0 +1,142 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +define i32 @test_ptrauth_call_sign(ptr %p) { +; CHECK-LABEL: @test_ptrauth_call_sign( +; CHECK-NEXT: [[V3:%.*]] = call i32 [[P:%.*]]() +; CHECK-NEXT: ret i32 [[V3]] +; + %v0 = ptrtoint ptr %p to i64 + %v1 = call i64 @llvm.ptrauth.sign(i64 %v0, i32 2, i64 5678) + %v2 = inttoptr i64 %v1 to ptr + %v3 = call i32 %v2() [ "ptrauth"(i32 2, i64 5678) ] + ret i32 %v3 +} + +define i32 @test_ptrauth_call_sign_otherbundle(ptr %p) { +; CHECK-LABEL: @test_ptrauth_call_sign_otherbundle( +; CHECK-NEXT: [[V3:%.*]] = call i32 [[P:%.*]]() [ "somebundle"(ptr null), "otherbundle"(i64 0) ] +; CHECK-NEXT: ret i32 [[V3]] +; + %v0 = ptrtoint ptr %p to i64 + %v1 = call i64 @llvm.ptrauth.sign(i64 %v0, i32 2, i64 5678) + %v2 = inttoptr i64 %v1 to ptr + %v3 = call i32 %v2() [ "somebundle"(ptr null), "ptrauth"(i32 2, i64 5678), "otherbundle"(i64 0) ] + ret i32 %v3 +} + +define i32 @test_ptrauth_call_resign(ptr %p) { +; CHECK-LABEL: @test_ptrauth_call_resign( +; CHECK-NEXT: [[V3:%.*]] = call i32 [[P:%.*]]() [ "ptrauth"(i32 1, i64 1234) ] +; CHECK-NEXT: ret i32 [[V3]] +; + %v0 = ptrtoint ptr %p to i64 + %v1 = call i64 @llvm.ptrauth.resign(i64 %v0, i32 1, i64 1234, i32 1, i64 5678) + %v2 = inttoptr i64 %v1 to ptr + %v3 = call i32 %v2() [ "ptrauth"(i32 1, i64 5678) ] + ret i32 %v3 +} + +define i32 @test_ptrauth_call_resign_blend(ptr %pp) { +; CHECK-LABEL: @test_ptrauth_call_resign_blend( +; CHECK-NEXT: [[V01:%.*]] = load ptr, ptr [[PP:%.*]], align 8 +; CHECK-NEXT: [[V6:%.*]] = call i32 [[V01]]() [ "ptrauth"(i32 1, i64 1234) ] +; CHECK-NEXT: ret i32 [[V6]] +; + %v0 = load ptr, ptr %pp, align 8 + %v1 = ptrtoint ptr %pp to i64 + %v2 = ptrtoint ptr %v0 to i64 + %v3 = call i64 @llvm.ptrauth.blend(i64 %v1, i64 5678) + %v4 = call i64 @llvm.ptrauth.resign(i64 %v2, i32 1, i64 1234, i32 1, i64 %v3) + %v5 = inttoptr i64 %v4 to ptr + %v6 = call i32 %v5() [ "ptrauth"(i32 1, i64 %v3) ] + ret i32 %v6 +} + +define i32 @test_ptrauth_call_resign_blend_2(ptr %pp) { +; CHECK-LABEL: @test_ptrauth_call_resign_blend_2( +; CHECK-NEXT: [[V01:%.*]] = load ptr, ptr [[PP:%.*]], align 8 +; CHECK-NEXT: [[V1:%.*]] = ptrtoint ptr [[PP]] to i64 +; CHECK-NEXT: [[V3:%.*]] = call i64 @llvm.ptrauth.blend(i64 [[V1]], i64 5678) +; CHECK-NEXT: [[V6:%.*]] = call i32 [[V01]]() [ "ptrauth"(i32 0, i64 [[V3]]) ] +; CHECK-NEXT: ret i32 [[V6]] +; + %v0 = load ptr, ptr %pp, align 8 + %v1 = ptrtoint ptr %pp to i64 + %v2 = ptrtoint ptr %v0 to i64 + %v3 = call i64 @llvm.ptrauth.blend(i64 %v1, i64 5678) + %v4 = call i64 @llvm.ptrauth.resign(i64 %v2, i32 0, i64 %v3, i32 0, i64 1234) + %v5 = inttoptr i64 %v4 to ptr + %v6 = call i32 %v5() [ "ptrauth"(i32 0, i64 1234) ] + ret i32 %v6 +} + +define i32 @test_ptrauth_call_resign_mismatch_key(ptr %p) { +; CHECK-LABEL: @test_ptrauth_call_resign_mismatch_key( +; CHECK-NEXT: [[V0:%.*]] = ptrtoint ptr [[P:%.*]] to i64 +; CHECK-NEXT: [[V1:%.*]] = call i64 @llvm.ptrauth.resign(i64 [[V0]], i32 1, i64 1234, i32 0, i64 5678) +; CHECK-NEXT: [[V2:%.*]] = inttoptr i64 [[V1]] to ptr +; CHECK-NEXT: [[V3:%.*]] = call i32 [[V2]]() [ "ptrauth"(i32 1, i64 5678) ] +; CHECK-NEXT: ret i32 [[V3]] +; + %v0 = ptrtoint ptr %p to i64 + %v1 = call i64 @llvm.ptrauth.resign(i64 %v0, i32 1, i64 1234, i32 0, i64 5678) + %v2 = inttoptr i64 %v1 to ptr + %v3 = call i32 %v2() [ "ptrauth"(i32 1, i64 5678) ] + ret i32 %v3 +} + +define i32 @test_ptrauth_call_resign_mismatch_disc(ptr %p) { +; CHECK-LABEL: @test_ptrauth_call_resign_mismatch_disc( +; CHECK-NEXT: [[V0:%.*]] = ptrtoint ptr [[P:%.*]] to i64 +; CHECK-NEXT: [[V1:%.*]] = call i64 @llvm.ptrauth.resign(i64 [[V0]], i32 1, i64 1234, i32 0, i64 9900) +; CHECK-NEXT: [[V2:%.*]] = inttoptr i64 [[V1]] to ptr +; CHECK-NEXT: [[V3:%.*]] = call i32 [[V2]]() [ "ptrauth"(i32 1, i64 5678) ] +; CHECK-NEXT: ret i32 [[V3]] +; + %v0 = ptrtoint ptr %p to i64 + %v1 = call i64 @llvm.ptrauth.resign(i64 %v0, i32 1, i64 1234, i32 0, i64 9900) + %v2 = inttoptr i64 %v1 to ptr + %v3 = call i32 %v2() [ "ptrauth"(i32 1, i64 5678) ] + ret i32 %v3 +} + +define i32 @test_ptrauth_call_resign_mismatch_blend(ptr %pp) { +; CHECK-LABEL: @test_ptrauth_call_resign_mismatch_blend( +; CHECK-NEXT: [[V0:%.*]] = load ptr, ptr [[PP:%.*]], align 8 +; CHECK-NEXT: [[V1:%.*]] = ptrtoint ptr [[PP]] to i64 +; CHECK-NEXT: [[V2:%.*]] = ptrtoint ptr [[V0]] to i64 +; CHECK-NEXT: [[V6:%.*]] = call i64 @llvm.ptrauth.blend(i64 [[V1]], i64 5678) +; CHECK-NEXT: [[V4:%.*]] = call i64 @llvm.ptrauth.resign(i64 [[V2]], i32 1, i64 1234, i32 1, i64 [[V6]]) +; CHECK-NEXT: [[V5:%.*]] = inttoptr i64 [[V4]] to ptr +; CHECK-NEXT: [[V3:%.*]] = call i32 [[V5]]() [ "ptrauth"(i32 1, i64 [[V1]]) ] +; CHECK-NEXT: ret i32 [[V3]] +; + %v0 = load ptr, ptr %pp, align 8 + %v1 = ptrtoint ptr %pp to i64 + %v2 = ptrtoint ptr %v0 to i64 + %v3 = call i64 @llvm.ptrauth.blend(i64 %v1, i64 5678) + %v4 = call i64 @llvm.ptrauth.resign(i64 %v2, i32 1, i64 1234, i32 1, i64 %v3) + %v5 = inttoptr i64 %v4 to ptr + %v6 = call i32 %v5() [ "ptrauth"(i32 1, i64 %v1) ] + ret i32 %v6 +} + +define i32 @test_ptrauth_call_resign_changing_call_key(ptr %p) { +; CHECK-LABEL: @test_ptrauth_call_resign_changing_call_key( +; CHECK-NEXT: [[V0:%.*]] = ptrtoint ptr [[P:%.*]] to i64 +; CHECK-NEXT: [[V1:%.*]] = call i64 @llvm.ptrauth.resign(i64 [[V0]], i32 2, i64 1234, i32 1, i64 5678) +; CHECK-NEXT: [[V2:%.*]] = inttoptr i64 [[V1]] to ptr +; CHECK-NEXT: [[V3:%.*]] = call i32 [[V2]]() [ "ptrauth"(i32 1, i64 5678) ] +; CHECK-NEXT: ret i32 [[V3]] +; + %v0 = ptrtoint ptr %p to i64 + %v1 = call i64 @llvm.ptrauth.resign(i64 %v0, i32 2, i64 1234, i32 1, i64 5678) + %v2 = inttoptr i64 %v1 to ptr + %v3 = call i32 %v2() [ "ptrauth"(i32 1, i64 5678) ] + ret i32 %v3 +} + +declare i64 @llvm.ptrauth.sign(i64, i32, i64) +declare i64 @llvm.ptrauth.resign(i64, i32, i64, i32, i64) +declare i64 @llvm.ptrauth.blend(i64, i64)