Skip to content

Commit 302135e

Browse files
committed
[clang] Implement pointer authentication for C++ member function pointers.
1 parent 0a5b1f1 commit 302135e

File tree

10 files changed

+812
-17
lines changed

10 files changed

+812
-17
lines changed

clang/include/clang/Basic/Features.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ FEATURE(ptrauth_objc_method_list_pointer, LangOpts.PointerAuthCalls)
109109
FEATURE(ptrauth_vtable_pointer_address_discrimination, LangOpts.PointerAuthVTPtrAddressDiscrimination)
110110
FEATURE(ptrauth_vtable_pointer_type_discrimination, LangOpts.PointerAuthVTPtrTypeDiscrimination)
111111
FEATURE(ptrauth_returns, LangOpts.PointerAuthReturns)
112+
FEATURE(ptrauth_member_function_pointer_type_discrimination, LangOpts.PointerAuthCalls)
112113
FEATURE(ptrauth_function_pointer_type_discrimination, LangOpts.FunctionPointerTypeDiscrimination)
113114
FEATURE(ptrauth_signed_block_descriptors, LangOpts.PointerAuthBlockDescriptorPointers)
114115
FEATURE(swiftasynccc,

clang/include/clang/Basic/PointerAuthOptions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,9 @@ struct PointerAuthOptions {
200200

201201
/// The ABI for variadic C++ virtual function pointers.
202202
PointerAuthSchema CXXVirtualVariadicFunctionPointers;
203+
204+
/// The ABI for C++ member function pointers.
205+
PointerAuthSchema CXXMemberFunctionPointers;
203206
};
204207

205208
} // end namespace clang

clang/lib/CodeGen/CGCall.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4980,7 +4980,8 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
49804980
ReturnValueSlot ReturnValue,
49814981
const CallArgList &CallArgs,
49824982
llvm::CallBase **callOrInvoke, bool IsMustTail,
4983-
SourceLocation Loc) {
4983+
SourceLocation Loc,
4984+
bool IsVirtualFunctionPointerThunk) {
49844985
// FIXME: We no longer need the types from CallArgs; lift up and simplify.
49854986

49864987
assert(Callee.isOrdinary() || Callee.isVirtual());
@@ -5043,8 +5044,12 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
50435044
RawAddress SRetAlloca = RawAddress::invalid();
50445045
llvm::Value *UnusedReturnSizePtr = nullptr;
50455046
if (RetAI.isIndirect() || RetAI.isInAlloca() || RetAI.isCoerceAndExpand()) {
5046-
if (!ReturnValue.isNull()) {
5047-
SRetPtr = ReturnValue.getValue();
5047+
if (IsVirtualFunctionPointerThunk && RetAI.isIndirect()) {
5048+
SRetPtr = makeNaturalAddressForPointer(CurFn->arg_begin() +
5049+
IRFunctionArgs.getSRetArgNo(),
5050+
RetTy, CharUnits::fromQuantity(1));
5051+
} else if (!ReturnValue.isNull()) {
5052+
SRetPtr = ReturnValue.getAddress();
50485053
} else {
50495054
SRetPtr = CreateMemTemp(RetTy, "tmp", &SRetAlloca);
50505055
if (HaveInsertPoint() && ReturnValue.isUnused()) {
@@ -5776,7 +5781,14 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
57765781
CallArgs.freeArgumentMemory(*this);
57775782

57785783
// Extract the return value.
5779-
RValue Ret = [&] {
5784+
RValue Ret;
5785+
5786+
// If the current function is a virtual function pointer thunk, avoid copying
5787+
// the return value of the musttail call to a temporary.
5788+
if (IsVirtualFunctionPointerThunk)
5789+
Ret = RValue::get(CI);
5790+
else
5791+
Ret = [&] {
57805792
switch (RetAI.getKind()) {
57815793
case ABIArgInfo::CoerceAndExpand: {
57825794
auto coercionType = RetAI.getCoerceAndExpandType();

clang/lib/CodeGen/CGPointerAuth.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,23 @@ CodeGenModule::getFunctionPointerAuthInfo(QualType T) {
106106
/* authenticatesNullValues */ false, Discriminator);
107107
}
108108

109+
CGPointerAuthInfo
110+
CodeGenModule::getMemberFunctionPointerAuthInfo(QualType functionType) {
111+
assert(functionType->getAs<MemberPointerType>() &&
112+
"MemberPointerType expected");
113+
auto &schema = getCodeGenOpts().PointerAuth.CXXMemberFunctionPointers;
114+
if (!schema)
115+
return CGPointerAuthInfo();
116+
117+
assert(!schema.isAddressDiscriminated() &&
118+
"function pointers cannot use address-specific discrimination");
119+
120+
auto discriminator =
121+
getPointerAuthOtherDiscriminator(schema, GlobalDecl(), functionType);
122+
return CGPointerAuthInfo(schema.getKey(), schema.getAuthenticationMode(),
123+
/* authenticatesNullValues */ false, discriminator);
124+
}
125+
109126
/// Return the natural pointer authentication for values of the given
110127
/// pointee type.
111128
static CGPointerAuthInfo
@@ -812,6 +829,7 @@ void ConstantAggregateBuilderBase::addSignedPointer(
812829
void CodeGenModule::destroyConstantSignedPointerCaches() {
813830
destroyCache<ByConstantCacheTy>(ConstantSignedPointersByConstant);
814831
destroyCache<FunctionPointerCacheTy>(SignedFunctionPointersByDeclAndType);
832+
destroyCache<ByDeclCacheTy>(SignedThunkPointers);
815833
}
816834

817835
/// If applicable, sign a given constant function pointer with the ABI rules for
@@ -869,6 +887,40 @@ llvm::Constant *CodeGenModule::getFunctionPointer(GlobalDecl GD,
869887
return getFunctionPointer(getRawFunctionPointer(GD, Ty), FuncType, GD);
870888
}
871889

890+
llvm::Constant *
891+
CodeGenModule::getMemberFunctionPointer(llvm::Constant *pointer,
892+
QualType functionType,
893+
const FunctionDecl *FD) {
894+
if (auto pointerAuth = getMemberFunctionPointerAuthInfo(functionType)) {
895+
llvm::Constant **entry = nullptr;
896+
if (FD) {
897+
auto &cache =
898+
getOrCreateCache<ByDeclCacheTy>(SignedThunkPointers);
899+
entry = &cache[FD->getCanonicalDecl()];
900+
if (*entry)
901+
return llvm::ConstantExpr::getBitCast(*entry, pointer->getType());
902+
}
903+
904+
pointer = getConstantSignedPointer(
905+
pointer, pointerAuth.getKey(), nullptr,
906+
cast_or_null<llvm::Constant>(pointerAuth.getDiscriminator()));
907+
908+
if (entry)
909+
*entry = pointer;
910+
}
911+
912+
return pointer;
913+
}
914+
915+
llvm::Constant *
916+
CodeGenModule::getMemberFunctionPointer(const FunctionDecl *FD, llvm::Type *Ty) {
917+
QualType functionType = FD->getType();
918+
functionType = getContext().getMemberPointerType(
919+
functionType, cast<CXXMethodDecl>(FD)->getParent()->getTypeForDecl());
920+
return getMemberFunctionPointer(getRawFunctionPointer(FD, Ty), functionType,
921+
FD);
922+
}
923+
872924
llvm::Value *CodeGenFunction::AuthPointerToPointerCast(llvm::Value *ResultPtr,
873925
QualType SourceType,
874926
QualType DestType) {

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4173,7 +4173,8 @@ class CodeGenFunction : public CodeGenTypeCache {
41734173
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
41744174
ReturnValueSlot ReturnValue, const CallArgList &Args,
41754175
llvm::CallBase **callOrInvoke, bool IsMustTail,
4176-
SourceLocation Loc);
4176+
SourceLocation Loc,
4177+
bool IsVirtualFunctionPointerThunk = false);
41774178
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
41784179
ReturnValueSlot ReturnValue, const CallArgList &Args,
41794180
llvm::CallBase **callOrInvoke = nullptr,

clang/lib/CodeGen/CodeGenModule.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ class CodeGenModule : public CodeGenTypeCache {
441441

442442
/// Signed constant pointers.
443443
void *SignedFunctionPointersByDeclAndType = nullptr;
444+
void *SignedThunkPointers = nullptr;
444445
void *ConstantSignedPointersByConstant = nullptr;
445446

446447
llvm::StringMap<llvm::GlobalVariable *> CFConstantStringMap;
@@ -964,6 +965,13 @@ class CodeGenModule : public CodeGenTypeCache {
964965
QualType functionType,
965966
GlobalDecl GD = GlobalDecl());
966967

968+
llvm::Constant *getMemberFunctionPointer(const FunctionDecl *FD,
969+
llvm::Type *Ty = nullptr);
970+
971+
llvm::Constant *getMemberFunctionPointer(llvm::Constant *pointer,
972+
QualType functionType,
973+
const FunctionDecl *FD = nullptr);
974+
967975
CGPointerAuthInfo getFunctionPointerAuthInfo(QualType functionType);
968976

969977
CGPointerAuthInfo getMemberFunctionPointerAuthInfo(QualType functionType);

0 commit comments

Comments
 (0)