Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions clang/lib/AST/InferAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
#include "clang/Basic/IdentifierTable.h"
#include "llvm/ADT/SmallPtrSet.h"

namespace clang {
namespace {
bool typeContainsPointer(QualType T,
llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
bool &IncompleteType) {
using namespace clang;
using namespace infer_alloc;

static bool
typeContainsPointer(QualType T,
llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
bool &IncompleteType) {
QualType CanonicalType = T.getCanonicalType();
if (CanonicalType->isPointerType())
return true; // base case
Expand Down Expand Up @@ -70,7 +72,7 @@ bool typeContainsPointer(QualType T,
}

/// Infer type from a simple sizeof expression.
QualType inferTypeFromSizeofExpr(const Expr *E) {
static QualType inferTypeFromSizeofExpr(const Expr *E) {
const Expr *Arg = E->IgnoreParenImpCasts();
if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) {
if (UET->getKind() == UETT_SizeOf) {
Expand All @@ -96,7 +98,7 @@ QualType inferTypeFromSizeofExpr(const Expr *E) {
///
/// malloc(sizeof(HasFlexArray) + sizeof(int) * 32); // infers 'HasFlexArray'
///
QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
static QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
const Expr *Arg = E->IgnoreParenImpCasts();
// The argument is a lone sizeof expression.
if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull())
Expand Down Expand Up @@ -132,7 +134,7 @@ QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
/// size_t my_size = sizeof(MyType);
/// void *x = malloc(my_size); // infers 'MyType'
///
QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
static QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
const Expr *Arg = E->IgnoreParenImpCasts();
if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) {
if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
Expand All @@ -148,21 +150,19 @@ QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
///
/// MyType *x = (MyType *)malloc(4096); // infers 'MyType'
///
QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
const CastExpr *CastE) {
static QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
const CastExpr *CastE) {
if (!CastE)
return QualType();
QualType PtrType = CastE->getType();
if (PtrType->isPointerType())
return PtrType->getPointeeType();
return QualType();
}
} // anonymous namespace

namespace infer_alloc {

QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
const CastExpr *CastE) {
QualType clang::infer_alloc::inferPossibleType(const CallExpr *E,
const ASTContext &Ctx,
const CastExpr *CastE) {
QualType AllocType;
// First check arguments.
for (const Expr *Arg : E->arguments()) {
Expand All @@ -179,7 +179,7 @@ QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
}

std::optional<llvm::AllocTokenMetadata>
getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
clang::infer_alloc::getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
llvm::AllocTokenMetadata ATMD;

// Get unique type name.
Expand All @@ -199,6 +199,3 @@ getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {

return ATMD;
}

} // namespace infer_alloc
} // namespace clang
9 changes: 1 addition & 8 deletions clang/lib/Frontend/CompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4565,14 +4565,7 @@ bool CompilerInvocation::ParseLangArgs(LangOptions &Opts, ArgList &Args,

if (const auto *Arg = Args.getLastArg(options::OPT_falloc_token_mode_EQ)) {
StringRef S = Arg->getValue();
auto Mode = llvm::StringSwitch<std::optional<llvm::AllocTokenMode>>(S)
.Case("increment", llvm::AllocTokenMode::Increment)
.Case("random", llvm::AllocTokenMode::Random)
.Case("typehash", llvm::AllocTokenMode::TypeHash)
.Case("typehashpointersplit",
llvm::AllocTokenMode::TypeHashPointerSplit)
.Default(std::nullopt);
if (Mode)
if (auto Mode = getAllocTokenModeFromString(S))
Opts.AllocTokenMode = Mode;
else
Diags.Report(diag::err_drv_invalid_value) << Arg->getAsString(Args) << S;
Expand Down
12 changes: 9 additions & 3 deletions llvm/include/llvm/Support/AllocToken.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define LLVM_SUPPORT_ALLOCTOKEN_H

#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <optional>

Expand All @@ -40,6 +41,11 @@ enum class AllocTokenMode {
inline constexpr AllocTokenMode DefaultAllocTokenMode =
AllocTokenMode::TypeHashPointerSplit;

/// Returns the AllocTokenMode from its canonical string name; if an invalid
/// name was provided returns nullopt.
LLVM_ABI std::optional<AllocTokenMode>
getAllocTokenModeFromString(StringRef Name);

/// Metadata about an allocation used to generate a token ID.
struct AllocTokenMetadata {
SmallString<64> TypeName;
Expand All @@ -53,9 +59,9 @@ struct AllocTokenMetadata {
/// \param Metadata The metadata about the allocation.
/// \param MaxTokens The maximum number of tokens (must not be 0)
/// \return The calculated allocation token ID, or std::nullopt.
std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
const AllocTokenMetadata &Metadata,
uint64_t MaxTokens);
LLVM_ABI std::optional<uint64_t>
getAllocToken(AllocTokenMode Mode, const AllocTokenMetadata &Metadata,
uint64_t MaxTokens);

} // end namespace llvm

Expand Down
9 changes: 1 addition & 8 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1102,14 +1102,7 @@ Expected<AllocTokenOptions> parseAllocTokenPassOptions(StringRef Params) {
std::tie(ParamName, Params) = Params.split(';');

if (ParamName.consume_front("mode=")) {
auto Mode = StringSwitch<std::optional<AllocTokenMode>>(ParamName)
.Case("increment", AllocTokenMode::Increment)
.Case("random", AllocTokenMode::Random)
.Case("typehash", AllocTokenMode::TypeHash)
.Case("typehashpointersplit",
AllocTokenMode::TypeHashPointerSplit)
.Default(std::nullopt);
if (Mode)
if (auto Mode = getAllocTokenModeFromString(ParamName))
Result.Mode = *Mode;
else
return make_error<StringError>(
Expand Down
35 changes: 25 additions & 10 deletions llvm/lib/Support/AllocToken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,46 @@
//===----------------------------------------------------------------------===//

#include "llvm/Support/AllocToken.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SipHash.h"

namespace llvm {
std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
const AllocTokenMetadata &Metadata,
uint64_t MaxTokens) {
assert(MaxTokens && "Must provide concrete max tokens");
using namespace llvm;

std::optional<AllocTokenMode>
llvm::getAllocTokenModeFromString(StringRef Name) {
return StringSwitch<std::optional<AllocTokenMode>>(Name)
.Case("increment", AllocTokenMode::Increment)
.Case("random", AllocTokenMode::Random)
.Case("typehash", AllocTokenMode::TypeHash)
.Case("typehashpointersplit", AllocTokenMode::TypeHashPointerSplit)
.Default(std::nullopt);
}

static uint64_t getStableHash(const AllocTokenMetadata &Metadata,
uint64_t MaxTokens) {
return getStableSipHash(Metadata.TypeName) % MaxTokens;
}

std::optional<uint64_t> llvm::getAllocToken(AllocTokenMode Mode,
const AllocTokenMetadata &Metadata,
uint64_t MaxTokens) {
assert(MaxTokens && "Must provide non-zero max tokens");

switch (Mode) {
case AllocTokenMode::Increment:
case AllocTokenMode::Random:
// Stateful modes cannot be implemented as a pure function.
return std::nullopt;

case AllocTokenMode::TypeHash: {
return getStableSipHash(Metadata.TypeName) % MaxTokens;
}
case AllocTokenMode::TypeHash:
return getStableHash(Metadata, MaxTokens);

case AllocTokenMode::TypeHashPointerSplit: {
if (MaxTokens == 1)
return 0;
const uint64_t HalfTokens = MaxTokens / 2;
uint64_t Hash = getStableSipHash(Metadata.TypeName) % HalfTokens;
uint64_t Hash = getStableHash(Metadata, HalfTokens);
if (Metadata.ContainsPointer)
Hash += HalfTokens;
return Hash;
Expand All @@ -43,4 +59,3 @@ std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,

llvm_unreachable("");
}
} // namespace llvm
12 changes: 5 additions & 7 deletions llvm/lib/Transforms/Instrumentation/AllocToken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ class TypeHashMode : public ModeBase {
if (MDNode *N = getAllocTokenMetadata(CB)) {
MDString *S = cast<MDString>(N->getOperand(0));
AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
if (auto Token =
getAllocTokenHash(TokenMode::TypeHash, Metadata, MaxTokens))
if (auto Token = getAllocToken(TokenMode::TypeHash, Metadata, MaxTokens))
return *Token;
}
// Fallback.
Expand Down Expand Up @@ -222,8 +221,8 @@ class TypeHashPointerSplitMode : public TypeHashMode {
if (MDNode *N = getAllocTokenMetadata(CB)) {
MDString *S = cast<MDString>(N->getOperand(0));
AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
if (auto Token = getAllocTokenHash(TokenMode::TypeHashPointerSplit,
Metadata, MaxTokens))
if (auto Token = getAllocToken(TokenMode::TypeHashPointerSplit, Metadata,
MaxTokens))
return *Token;
}
// Pick the fallback token (ClFallbackToken), which by default is 0, meaning
Expand Down Expand Up @@ -357,9 +356,8 @@ bool AllocToken::instrumentFunction(Function &F) {
}

if (!IntrinsicInsts.empty()) {
for (auto *II : IntrinsicInsts) {
for (auto *II : IntrinsicInsts)
replaceIntrinsicInst(II, ORE);
}
Modified = true;
NumFunctionsModified++;
}
Expand All @@ -381,7 +379,7 @@ AllocToken::shouldInstrumentCall(const CallBase &CB,
if (TLI.getLibFunc(*Callee, Func)) {
if (isInstrumentableLibFunc(Func, CB, TLI))
return Func;
} else if (Options.Extended && getAllocTokenMetadata(CB)) {
} else if (Options.Extended && CB.getMetadata(LLVMContext::MD_alloc_token)) {
return NotLibFunc;
}

Expand Down
Loading