Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,19 @@ static auto valueOperatorCall() {
isStatusOrOperatorCallWithName("->")));
}

static clang::ast_matchers::TypeMatcher statusType() {
using namespace ::clang::ast_matchers; // NOLINT: Too many names
return hasCanonicalType(qualType(hasDeclaration(statusClass())));
}

static auto isComparisonOperatorCall(llvm::StringRef operator_name) {
using namespace ::clang::ast_matchers; // NOLINT: Too many names
return cxxOperatorCallExpr(
hasOverloadedOperatorName(operator_name), argumentCountIs(2),
hasArgument(0, anyOf(hasType(statusType()), hasType(statusOrType()))),
hasArgument(1, anyOf(hasType(statusType()), hasType(statusOrType()))));
}

static auto
buildDiagnoseMatchSwitch(const UncheckedStatusOrAccessModelOptions &Options) {
return CFGMatchSwitchBuilder<const Environment,
Expand Down Expand Up @@ -312,6 +325,101 @@ static void transferStatusUpdateCall(const CXXMemberCallExpr *Expr,
State.Env.setValue(locForOk(*ThisLoc), NewVal);
}

static BoolValue *evaluateStatusEquality(RecordStorageLocation &LhsStatusLoc,
RecordStorageLocation &RhsStatusLoc,
Environment &Env) {
auto &A = Env.arena();
// Logically, a Status object is composed of an error code that could take one
// of multiple possible values, including the "ok" value. We track whether a
// Status object has an "ok" value and represent this as an `ok` bit. Equality
// of Status objects compares their error codes. Therefore, merely comparing
// the `ok` bits isn't sufficient: when two Status objects are assigned non-ok
// error codes the equality of their respective error codes matters. Since we
// only track the `ok` bits, we can't make any conclusions about equality when
// we know that two Status objects have non-ok values.

auto &LhsOkVal = valForOk(LhsStatusLoc, Env);
auto &RhsOkVal = valForOk(RhsStatusLoc, Env);

auto &Res = Env.makeAtomicBoolValue();

// lhs && rhs => res (a.k.a. !res => !lhs || !rhs)
Env.assume(A.makeImplies(A.makeAnd(LhsOkVal.formula(), RhsOkVal.formula()),
Res.formula()));
// res => (lhs == rhs)
Env.assume(A.makeImplies(
Res.formula(), A.makeEquals(LhsOkVal.formula(), RhsOkVal.formula())));

return &Res;
}

static BoolValue *
evaluateStatusOrEquality(RecordStorageLocation &LhsStatusOrLoc,
RecordStorageLocation &RhsStatusOrLoc,
Environment &Env) {
auto &A = Env.arena();
// Logically, a StatusOr<T> object is composed of two values - a Status and a
// value of type T. Equality of StatusOr objects compares both values.
// Therefore, merely comparing the `ok` bits of the Status values isn't
// sufficient. When two StatusOr objects are engaged, the equality of their
// respective values of type T matters. Similarly, when two StatusOr objects
// have Status values that have non-ok error codes, the equality of the error
// codes matters. Since we only track the `ok` bits of the Status values, we
// can't make any conclusions about equality when we know that two StatusOr
// objects are engaged or when their Status values contain non-ok error codes.
auto &LhsOkVal = valForOk(locForStatus(LhsStatusOrLoc), Env);
auto &RhsOkVal = valForOk(locForStatus(RhsStatusOrLoc), Env);
auto &res = Env.makeAtomicBoolValue();

// res => (lhs == rhs)
Env.assume(A.makeImplies(
res.formula(), A.makeEquals(LhsOkVal.formula(), RhsOkVal.formula())));
return &res;
}

static BoolValue *evaluateEquality(const Expr *LhsExpr, const Expr *RhsExpr,
Environment &Env) {
// Check the type of both sides in case an operator== is added that admits
// different types.
if (isStatusOrType(LhsExpr->getType()) &&
isStatusOrType(RhsExpr->getType())) {
auto *LhsStatusOrLoc = Env.get<RecordStorageLocation>(*LhsExpr);
if (LhsStatusOrLoc == nullptr)
return nullptr;
auto *RhsStatusOrLoc = Env.get<RecordStorageLocation>(*RhsExpr);
if (RhsStatusOrLoc == nullptr)
return nullptr;

return evaluateStatusOrEquality(*LhsStatusOrLoc, *RhsStatusOrLoc, Env);
}
if (isStatusType(LhsExpr->getType()) && isStatusType(RhsExpr->getType())) {
auto *LhsStatusLoc = Env.get<RecordStorageLocation>(*LhsExpr);
if (LhsStatusLoc == nullptr)
return nullptr;

auto *RhsStatusLoc = Env.get<RecordStorageLocation>(*RhsExpr);
if (RhsStatusLoc == nullptr)
return nullptr;

return evaluateStatusEquality(*LhsStatusLoc, *RhsStatusLoc, Env);
}
return nullptr;
}

static void transferComparisonOperator(const CXXOperatorCallExpr *Expr,
LatticeTransferState &State,
bool IsNegative) {
auto *LhsAndRhsVal =
evaluateEquality(Expr->getArg(0), Expr->getArg(1), State.Env);
if (LhsAndRhsVal == nullptr)
return;

if (IsNegative)
State.Env.setValue(*Expr, State.Env.makeNot(*LhsAndRhsVal));
else
State.Env.setValue(*Expr, *LhsAndRhsVal);
}

CFGMatchSwitch<LatticeTransferState>
buildTransferMatchSwitch(ASTContext &Ctx,
CFGMatchSwitchBuilder<LatticeTransferState> Builder) {
Expand All @@ -325,6 +433,20 @@ buildTransferMatchSwitch(ASTContext &Ctx,
transferStatusOkCall)
.CaseOfCFGStmt<CXXMemberCallExpr>(isStatusMemberCallWithName("Update"),
transferStatusUpdateCall)
.CaseOfCFGStmt<CXXOperatorCallExpr>(
isComparisonOperatorCall("=="),
[](const CXXOperatorCallExpr *Expr, const MatchFinder::MatchResult &,
LatticeTransferState &State) {
transferComparisonOperator(Expr, State,
/*IsNegative=*/false);
})
.CaseOfCFGStmt<CXXOperatorCallExpr>(
isComparisonOperatorCall("!="),
[](const CXXOperatorCallExpr *Expr, const MatchFinder::MatchResult &,
LatticeTransferState &State) {
transferComparisonOperator(Expr, State,
/*IsNegative=*/true);
})
.Build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2614,6 +2614,233 @@ TEST_P(UncheckedStatusOrAccessModelTest, StatusUpdate) {
)cc");
}

TEST_P(UncheckedStatusOrAccessModelTest, EqualityCheck) {
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (x == y)
y.value();
else
y.value(); // [[unsafe]]
}
}
)cc");
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (y == x)
y.value();
else
y.value(); // [[unsafe]]
}
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (x != y)
y.value(); // [[unsafe]]
else
y.value();
}
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (y != x)
y.value(); // [[unsafe]]
else
y.value();
}
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (!(x == y))
y.value(); // [[unsafe]]
else
y.value();
}
}
)cc");
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (!(x != y))
y.value();
else
y.value(); // [[unsafe]]
}
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x == y)
if (x.ok()) y.value();
}
)cc");
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (x.status() == y.status())
y.value();
else
y.value(); // [[unsafe]]
}
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (x.status() != y.status())
y.value(); // [[unsafe]]
else
y.value();
}
}
)cc");
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (x.ok() == y.ok())
y.value();
else
y.value(); // [[unsafe]]
}
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (x.ok() != y.ok())
y.value(); // [[unsafe]]
else
y.value();
}
}
)cc");
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (x.status().ok() == y.status().ok())
y.value();
else
y.value(); // [[unsafe]]
}
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (x.status().ok() != y.status().ok())
y.value(); // [[unsafe]]
else
y.value();
}
}
)cc");
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (x.status().ok() == y.ok())
y.value();
else
y.value(); // [[unsafe]]
}
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT x, STATUSOR_INT y) {
if (x.ok()) {
if (x.status().ok() != y.ok())
y.value(); // [[unsafe]]
else
y.value();
}
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(bool b, STATUSOR_INT sor) {
if (sor.ok() == b) {
if (b) sor.value();
}
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT sor) {
if (sor.ok() == true) sor.value();
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT sor) {
if (sor.ok() == false) sor.value(); // [[unsafe]]
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(bool b) {
STATUSOR_INT sor1;
STATUSOR_INT sor2 = Make<STATUSOR_INT>();
if (sor1 == sor2) sor2.value(); // [[unsafe]]
}
)cc");
ExpectDiagnosticsFor(R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(bool b) {
STATUSOR_INT sor1 = Make<STATUSOR_INT>();
STATUSOR_INT sor2;
if (sor1 == sor2) sor1.value(); // [[unsafe]]
}
)cc");
}

} // namespace

std::string
Expand Down
Loading