Skip to content
6 changes: 3 additions & 3 deletions clang/lib/CodeGen/CGPointerAuth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,9 @@ CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key,
IntegerDiscriminator = llvm::ConstantInt::get(Int64Ty, 0);
}

return llvm::ConstantPtrAuth::get(Pointer,
llvm::ConstantInt::get(Int32Ty, Key),
IntegerDiscriminator, AddressDiscriminator);
return llvm::ConstantPtrAuth::get(
Pointer, llvm::ConstantInt::get(Int32Ty, Key), IntegerDiscriminator,
AddressDiscriminator, llvm::Constant::getNullValue(UnqualPtrTy));
}

/// Does a given PointerAuthScheme require us to sign a value
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Bitcode/LLVMBitCodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ enum ConstantsCodes {
CST_CODE_CE_GEP_WITH_INRANGE = 31, // [opty, flags, range, n x operands]
CST_CODE_CE_GEP = 32, // [opty, flags, n x operands]
CST_CODE_PTRAUTH = 33, // [ptr, key, disc, addrdisc]
CST_CODE_PTRAUTH2 = 34, // [ptr, key, disc, addrdisc, DeactivationSymbol]
};

/// CastOpcodes - These are values used in the bitcode files to encode which
Expand Down
13 changes: 9 additions & 4 deletions llvm/include/llvm/IR/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -1031,10 +1031,10 @@ class ConstantPtrAuth final : public Constant {
friend struct ConstantPtrAuthKeyType;
friend class Constant;

constexpr static IntrusiveOperandsAllocMarker AllocMarker{4};
constexpr static IntrusiveOperandsAllocMarker AllocMarker{5};

ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc,
Constant *AddrDisc);
Constant *AddrDisc, Constant *DeactivationSymbol);

void *operator new(size_t s) { return User::operator new(s, AllocMarker); }

Expand All @@ -1044,7 +1044,8 @@ class ConstantPtrAuth final : public Constant {
public:
/// Return a pointer signed with the specified parameters.
static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc);
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol);

/// Produce a new ptrauth expression signing the given value using
/// the same schema as is stored in one.
Expand Down Expand Up @@ -1076,6 +1077,10 @@ class ConstantPtrAuth final : public Constant {
return !getAddrDiscriminator()->isNullValue();
}

Constant *getDeactivationSymbol() const {
return cast<Constant>(Op<4>().get());
}

/// A constant value for the address discriminator which has special
/// significance to ctors/dtors lowering. Regular address discrimination can't
/// be applied for them since uses of llvm.global_{c|d}tors are disallowed
Expand Down Expand Up @@ -1103,7 +1108,7 @@ class ConstantPtrAuth final : public Constant {

template <>
struct OperandTraits<ConstantPtrAuth>
: public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
: public FixedNumOperandTraits<ConstantPtrAuth, 5> {};

DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)

Expand Down
5 changes: 4 additions & 1 deletion llvm/include/llvm/SandboxIR/Constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,8 @@ class ConstantPtrAuth final : public Constant {
public:
/// Return a pointer signed with the specified parameters.
static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc);
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol);
/// The pointer that is signed in this ptrauth signed pointer.
Constant *getPointer() const;

Expand All @@ -1399,6 +1400,8 @@ class ConstantPtrAuth final : public Constant {
/// the only global-initializer user of the ptrauth signed pointer.
Constant *getAddrDiscriminator() const;

Constant *getDeactivationSymbol() const;

/// Whether there is any non-null address discriminator.
bool hasAddressDiscriminator() const {
return cast<llvm::ConstantPtrAuth>(Val)->hasAddressDiscriminator();
Expand Down
29 changes: 21 additions & 8 deletions llvm/lib/AsmParser/LLParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4218,11 +4218,12 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
}
case lltok::kw_ptrauth: {
// ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key>
// (',' i64 <disc> (',' ptr addrdisc)? )? ')'
// (',' i64 <disc> (',' ptr addrdisc (',' ptr ds)? )? )? ')'
Lex.Lex();

Constant *Ptr, *Key;
Constant *Disc = nullptr, *AddrDisc = nullptr;
Constant *Disc = nullptr, *AddrDisc = nullptr,
*DeactivationSymbol = nullptr;

if (parseToken(lltok::lparen,
"expected '(' in constant ptrauth expression") ||
Expand All @@ -4231,11 +4232,14 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
"expected comma in constant ptrauth expression") ||
parseGlobalTypeAndValue(Key))
return true;
// If present, parse the optional disc/addrdisc.
if (EatIfPresent(lltok::comma))
if (parseGlobalTypeAndValue(Disc) ||
(EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)))
return true;
// If present, parse the optional disc/addrdisc/ds.
if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(Disc))
return true;
if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc))
return true;
if (EatIfPresent(lltok::comma) &&
parseGlobalTypeAndValue(DeactivationSymbol))
return true;
if (parseToken(lltok::rparen,
"expected ')' in constant ptrauth expression"))
return true;
Expand Down Expand Up @@ -4266,7 +4270,16 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0));
}

ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc);
if (DeactivationSymbol) {
if (!DeactivationSymbol->getType()->isPointerTy())
return error(
ID.Loc, "constant ptrauth deactivation symbol must be a pointer");
} else {
DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0));
}

ID.ConstantVal =
ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc, DeactivationSymbol);
ID.Kind = ValID::t_Constant;
return false;
}
Expand Down
18 changes: 17 additions & 1 deletion llvm/lib/Bitcode/Reader/BitcodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1603,7 +1603,13 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
if (!Disc)
return error("ptrauth disc operand must be ConstantInt");

C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]);
auto *DeactivationSymbol =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] while there is a type named on the rhs, I think based on the ternary rather than single option, this should be Constant * instead of auto *

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

ConstOps.size() > 4 ? ConstOps[4]
: ConstantPointerNull::get(cast<PointerType>(
ConstOps[3]->getType()));

C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3],
DeactivationSymbol);
break;
}
case BitcodeConstant::NoCFIOpcode: {
Expand Down Expand Up @@ -3801,6 +3807,16 @@ Error BitcodeReader::parseConstants() {
(unsigned)Record[2], (unsigned)Record[3]});
break;
}
case bitc::CST_CODE_PTRAUTH2: {
if (Record.size() < 4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be 5?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, fixed

return error("Invalid ptrauth record");
// Ptr, Key, Disc, AddrDisc, DeactivationSymbol
V = BitcodeConstant::create(
Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
{(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2],
(unsigned)Record[3], (unsigned)Record[4]});
break;
}
}

assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/IR/AsmWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1658,12 +1658,14 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
if (const ConstantPtrAuth *CPA = dyn_cast<ConstantPtrAuth>(CV)) {
Out << "ptrauth (";

// ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)
// ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC[, ptr DS]?]?]?)
unsigned NumOpsToWrite = 2;
if (!CPA->getOperand(2)->isNullValue())
NumOpsToWrite = 3;
if (!CPA->getOperand(3)->isNullValue())
NumOpsToWrite = 4;
if (!CPA->getOperand(4)->isNullValue())
NumOpsToWrite = 5;

ListSeparator LS;
for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) {
Expand Down
12 changes: 8 additions & 4 deletions llvm/lib/IR/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2056,19 +2056,22 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
//

ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc) {
Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc};
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol) {
Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc, DeactivationSymbol};
ConstantPtrAuthKeyType MapKey(ArgVec);
LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
}

ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator());
return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator(),
getDeactivationSymbol());
}

ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc)
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol)
: Constant(Ptr->getType(), Value::ConstantPtrAuthVal, AllocMarker) {
assert(Ptr->getType()->isPointerTy());
assert(Key->getBitWidth() == 32);
Expand All @@ -2078,6 +2081,7 @@ ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
setOperand(1, Key);
setOperand(2, Disc);
setOperand(3, AddrDisc);
setOperand(4, DeactivationSymbol);
}

/// Remove the constant from the constant table.
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/IR/ConstantsContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,8 @@ struct ConstantPtrAuthKeyType {

ConstantPtrAuth *create(TypeClass *Ty) const {
return new ConstantPtrAuth(Operands[0], cast<ConstantInt>(Operands[1]),
cast<ConstantInt>(Operands[2]), Operands[3]);
cast<ConstantInt>(Operands[2]), Operands[3],
Operands[4]);
}
};

Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/IR/Core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1699,7 +1699,9 @@ LLVMValueRef LLVMConstantPtrAuth(LLVMValueRef Ptr, LLVMValueRef Key,
LLVMValueRef Disc, LLVMValueRef AddrDisc) {
return wrap(ConstantPtrAuth::get(
unwrap<Constant>(Ptr), unwrap<ConstantInt>(Key),
unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc)));
unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc),
ConstantPointerNull::get(
cast<PointerType>(unwrap<Constant>(AddrDisc)->getType()))));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to extend the C API to give access to this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reckon that could be done in a followup if anyone needs it.

}

/*-- Opcode mapping */
Expand Down
11 changes: 9 additions & 2 deletions llvm/lib/SandboxIR/Constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,12 @@ PointerType *NoCFIValue::getType() const {
}

ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc) {
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol) {
auto *LLVMC = llvm::ConstantPtrAuth::get(
cast<llvm::Constant>(Ptr->Val), cast<llvm::ConstantInt>(Key->Val),
cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val));
cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val),
cast<llvm::Constant>(DeactivationSymbol->Val));
return cast<ConstantPtrAuth>(Ptr->getContext().getOrCreateConstant(LLVMC));
}

Expand All @@ -470,6 +472,11 @@ Constant *ConstantPtrAuth::getAddrDiscriminator() const {
cast<llvm::ConstantPtrAuth>(Val)->getAddrDiscriminator());
}

Constant *ConstantPtrAuth::getDeactivationSymbol() const {
return Ctx.getOrCreateConstant(
cast<llvm::ConstantPtrAuth>(Val)->getDeactivationSymbol());
}

ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
auto *LLVMC = cast<llvm::ConstantPtrAuth>(Val)->getWithSameSchema(
cast<llvm::Constant>(Pointer->Val));
Expand Down
37 changes: 31 additions & 6 deletions llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class AArch64AsmPrinter : public AsmPrinter {

const MCExpr *emitPAuthRelocationAsIRelative(
const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
bool HasAddressDiversity, bool IsDSOLocal);
bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr);

/// tblgen'erated driver function for lowering simple MI->MC
/// pseudo instructions.
Expand Down Expand Up @@ -2301,15 +2301,17 @@ static void emitAddress(MCStreamer &Streamer, MCRegister Reg,
}

static bool targetSupportsPAuthRelocation(const Triple &TT,
const MCExpr *Target) {
const MCExpr *Target,
const MCExpr *DSExpr) {
// No released version of glibc supports PAuth relocations.
if (TT.isOSGlibc())
return false;

// We emit PAuth constants as IRELATIVE relocations in cases where the
// constant cannot be represented as a PAuth relocation:
// 1) The signed value is not a symbol.
return !isa<MCConstantExpr>(Target);
// 1) There is a deactivation symbol.
// 2) The signed value is not a symbol.
return !DSExpr && !isa<MCConstantExpr>(Target);
}

static bool targetSupportsIRelativeRelocation(const Triple &TT) {
Expand All @@ -2326,7 +2328,7 @@ static bool targetSupportsIRelativeRelocation(const Triple &TT) {

const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
bool HasAddressDiversity, bool IsDSOLocal) {
bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr) {
const Triple &TT = TM.getTargetTriple();

// We only emit an IRELATIVE relocation if the target supports IRELATIVE and
Expand Down Expand Up @@ -2388,6 +2390,18 @@ const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
MCSymbolRefExpr::create(EmuPAC, OutStreamer->getContext());
OutStreamer->emitInstruction(MCInstBuilder(AArch64::B).addExpr(EmuPACRef),
*STI);

if (DSExpr) {
auto *PrePACInstExpr =
MCSymbolRefExpr::create(PrePACInst, OutStreamer->getContext());
OutStreamer->emitRelocDirective(*PrePACInstExpr, "R_AARCH64_INST32", DSExpr,
SMLoc(), *STI);
}

// We need a RET despite the above tail call because the deactivation symbol
// may replace it with a NOP.
OutStreamer->emitInstruction(MCInstBuilder(AArch64::RET).addReg(AArch64::LR),
*STI);
OutStreamer->popSection();

return MCSymbolRefExpr::create(IRelativeSym, AArch64MCExpr::VK_FUNCINIT,
Expand Down Expand Up @@ -2419,6 +2433,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
Sym = MCConstantExpr::create(Offset.getSExtValue(), Ctx);
}

const MCExpr *DSExpr = nullptr;
if (auto *DS = dyn_cast<GlobalValue>(CPA.getDeactivationSymbol())) {
if (isa<GlobalAlias>(DS))
return Sym;
DSExpr = MCSymbolRefExpr::create(getSymbol(DS), Ctx);
}

uint64_t KeyID = CPA.getKey()->getZExtValue();
// We later rely on valid KeyID value in AArch64PACKeyIDToString call from
// AArch64AuthMCExpr::printImpl, so fail fast.
Expand All @@ -2435,9 +2456,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
// Check if we need to represent this with an IRELATIVE and emit it if so.
if (auto *IFuncSym = emitPAuthRelocationAsIRelative(
Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(),
BaseGVB && BaseGVB->isDSOLocal()))
BaseGVB && BaseGVB->isDSOLocal(), DSExpr))
return IFuncSym;

if (DSExpr)
report_fatal_error("deactivation symbols unsupported in constant "
"expressions on this target");

// Finally build the complete @AUTH expr.
return AArch64AuthMCExpr::create(Sym, Disc, AArch64PACKey::ID(KeyID),
CPA.hasAddressDiscriminator(), Ctx);
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2979,9 +2979,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) {
auto *SignKey = cast<ConstantInt>(II->getArgOperand(3));
auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4));
auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy());
auto *Null = ConstantPointerNull::get(Builder.getPtrTy());
auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey,
SignDisc, SignAddrDisc);
SignDisc, Null, Null);
replaceInstUsesWith(
*II, ConstantExpr::getPointerCast(NewCPA, II->getType()));
return eraseInstFromFunction(*II);
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Utils/ValueMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,9 @@ Value *Mapper::mapValue(const Value *V) {
if (isa<ConstantVector>(C))
return getVM()[V] = ConstantVector::get(Ops);
if (isa<ConstantPtrAuth>(C))
return getVM()[V] = ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
cast<ConstantInt>(Ops[2]), Ops[3]);
return getVM()[V] =
ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
cast<ConstantInt>(Ops[2]), Ops[3], Ops[4]);
// If this is a no-operand constant, it must be because the type was remapped.
if (isa<PoisonValue>(C))
return getVM()[V] = PoisonValue::get(NewTy);
Expand Down
2 changes: 1 addition & 1 deletion llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1398,7 +1398,7 @@ define ptr @foo() {
// Check get(), getKey(), getDiscriminator(), getAddrDiscriminator().
auto *NewPtrAuth = sandboxir::ConstantPtrAuth::get(
&F, PtrAuth->getKey(), PtrAuth->getDiscriminator(),
PtrAuth->getAddrDiscriminator());
PtrAuth->getAddrDiscriminator(), PtrAuth->getDeactivationSymbol());
EXPECT_EQ(NewPtrAuth, PtrAuth);
// Check hasAddressDiscriminator().
EXPECT_EQ(PtrAuth->hasAddressDiscriminator(),
Expand Down
Loading
Loading