Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ struct UncheckedOptionalAccessModelOptions {
/// can't identify when their results are used safely (across calls),
/// resulting in false positives in all such cases. Note: this option does not
/// cover access through `operator[]`.
/// FIXME: we currently cache and equate the result of const accessors
/// returning pointers, so cover the case of operator-> followed by
/// operator->, which covers the common case of smart pointers. We also cover
/// some limited cases of returning references (if return type is an optional
/// type), so cover some cases of operator* followed by operator*. We don't
/// cover mixing operator-> and operator*. Once we are confident in this const
/// accessor caching, we shouldn't need the IgnoreSmartPointerDereference
/// option anymore.
bool IgnoreSmartPointerDereference = false;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ auto isZeroParamConstMemberCall() {
callee(cxxMethodDecl(parameterCountIs(0), isConst())));
}

auto isZeroParamConstMemberOperatorCall() {
return cxxOperatorCallExpr(
callee(cxxMethodDecl(parameterCountIs(0), isConst())));
}

auto isNonConstMemberCall() {
return cxxMemberCallExpr(callee(cxxMethodDecl(unless(isConst()))));
}
Expand Down Expand Up @@ -572,9 +577,10 @@ void handleConstMemberCall(const CallExpr *CE,
return;
}

// Cache if the const method returns a boolean type.
// Cache if the const method returns a boolean or pointer type.
// We may decide to cache other return types in the future.
if (RecordLoc != nullptr && CE->getType()->isBooleanType()) {
if (RecordLoc != nullptr &&
(CE->getType()->isBooleanType() || CE->getType()->isPointerType())) {
Value *Val = State.Lattice.getOrCreateConstMethodReturnValue(*RecordLoc, CE,
State.Env);
if (Val == nullptr)
Expand All @@ -597,6 +603,14 @@ void transferValue_ConstMemberCall(const CXXMemberCallExpr *MCE,
MCE, dataflow::getImplicitObjectLocation(*MCE, State.Env), Result, State);
}

void transferValue_ConstMemberOperatorCall(
const CXXOperatorCallExpr *OCE, const MatchFinder::MatchResult &Result,
LatticeTransferState &State) {
auto *RecordLoc = cast_or_null<dataflow::RecordStorageLocation>(
State.Env.getStorageLocation(*OCE->getArg(0)));
handleConstMemberCall(OCE, RecordLoc, Result, State);
}

void handleNonConstMemberCall(const CallExpr *CE,
dataflow::RecordStorageLocation *RecordLoc,
const MatchFinder::MatchResult &Result,
Expand Down Expand Up @@ -1020,6 +1034,8 @@ auto buildTransferMatchSwitch() {
// const accessor calls
.CaseOfCFGStmt<CXXMemberCallExpr>(isZeroParamConstMemberCall(),
transferValue_ConstMemberCall)
.CaseOfCFGStmt<CXXOperatorCallExpr>(isZeroParamConstMemberOperatorCall(),
transferValue_ConstMemberOperatorCall)
// non-const member calls that may modify the state of an object.
.CaseOfCFGStmt<CXXMemberCallExpr>(isNonConstMemberCall(),
transferValue_NonConstMemberCall)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1282,28 +1282,35 @@ static raw_ostream &operator<<(raw_ostream &OS,
class UncheckedOptionalAccessTest
: public ::testing::TestWithParam<OptionalTypeIdentifier> {
protected:
void ExpectDiagnosticsFor(std::string SourceCode) {
ExpectDiagnosticsFor(SourceCode, ast_matchers::hasName("target"));
void ExpectDiagnosticsFor(std::string SourceCode,
bool IgnoreSmartPointerDereference = true) {
ExpectDiagnosticsFor(SourceCode, ast_matchers::hasName("target"),
IgnoreSmartPointerDereference);
}

void ExpectDiagnosticsForLambda(std::string SourceCode) {
void ExpectDiagnosticsForLambda(std::string SourceCode,
bool IgnoreSmartPointerDereference = true) {
ExpectDiagnosticsFor(
SourceCode, ast_matchers::hasDeclContext(
ast_matchers::cxxRecordDecl(ast_matchers::isLambda())));
SourceCode,
ast_matchers::hasDeclContext(
ast_matchers::cxxRecordDecl(ast_matchers::isLambda())),
IgnoreSmartPointerDereference);
}

template <typename FuncDeclMatcher>
void ExpectDiagnosticsFor(std::string SourceCode,
FuncDeclMatcher FuncMatcher) {
void ExpectDiagnosticsFor(std::string SourceCode, FuncDeclMatcher FuncMatcher,
bool IgnoreSmartPointerDereference = true) {
// Run in C++17 and C++20 mode to cover differences in the AST between modes
// (e.g. C++20 can contain `CXXRewrittenBinaryOperator`).
for (const char *CxxMode : {"-std=c++17", "-std=c++20"})
ExpectDiagnosticsFor(SourceCode, FuncMatcher, CxxMode);
ExpectDiagnosticsFor(SourceCode, FuncMatcher, CxxMode,
IgnoreSmartPointerDereference);
}

template <typename FuncDeclMatcher>
void ExpectDiagnosticsFor(std::string SourceCode, FuncDeclMatcher FuncMatcher,
const char *CxxMode) {
const char *CxxMode,
bool IgnoreSmartPointerDereference) {
ReplaceAllOccurrences(SourceCode, "$ns", GetParam().NamespaceName);
ReplaceAllOccurrences(SourceCode, "$optional", GetParam().TypeName);

Expand All @@ -1328,8 +1335,7 @@ class UncheckedOptionalAccessTest
template <typename T>
T Make();
)");
UncheckedOptionalAccessModelOptions Options{
/*IgnoreSmartPointerDereference=*/true};
UncheckedOptionalAccessModelOptions Options{IgnoreSmartPointerDereference};
std::vector<SourceLocation> Diagnostics;
llvm::Error Error = checkDataflow<UncheckedOptionalAccessModel>(
AnalysisInputs<UncheckedOptionalAccessModel>(
Expand Down Expand Up @@ -3721,6 +3727,50 @@ TEST_P(UncheckedOptionalAccessTest, ConstByValueAccessorWithModInBetween) {
)cc");
}

TEST_P(UncheckedOptionalAccessTest, ConstPointerAccessor) {
ExpectDiagnosticsFor(R"cc(
#include "unchecked_optional_access_test.h"

struct B {
$ns::$optional<int> x;
};

struct MyUniquePtr {
B* operator->() const;
};

void target(MyUniquePtr a) {
if (a->x) {
*a->x;
}
}
)cc",
/*IgnoreSmartPointerDereference=*/false);
}

TEST_P(UncheckedOptionalAccessTest, ConstPointerAccessorWithModInBetween) {
ExpectDiagnosticsFor(R"cc(
#include "unchecked_optional_access_test.h"

struct B {
$ns::$optional<int> x;
};

struct MyUniquePtr {
B* operator->() const;
void reset(B*);
};

void target(MyUniquePtr a) {
if (a->x) {
a.reset(nullptr);
*a->x; // [[unsafe]]
}
}
)cc",
/*IgnoreSmartPointerDereference=*/false);
}

TEST_P(UncheckedOptionalAccessTest, ConstBoolAccessor) {
ExpectDiagnosticsFor(R"cc(
#include "unchecked_optional_access_test.h"
Expand Down
Loading