Skip to content

Commit c7810cb

Browse files
fmayeraidint
authored andcommitted
[FlowSensitive] [StatusOr] [12/N] Add support for smart pointers (llvm#170943)
1 parent e3905a4 commit c7810cb

File tree

6 files changed

+203
-6
lines changed

6 files changed

+203
-6
lines changed

clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "clang/Analysis/FlowSensitive/DataflowEnvironment.h"
2626
#include "clang/Analysis/FlowSensitive/MatchSwitch.h"
2727
#include "clang/Analysis/FlowSensitive/RecordOps.h"
28+
#include "clang/Analysis/FlowSensitive/SmartPointerAccessorCaching.h"
2829
#include "clang/Analysis/FlowSensitive/StorageLocation.h"
2930
#include "clang/Analysis/FlowSensitive/Value.h"
3031
#include "clang/Basic/LLVM.h"
@@ -849,6 +850,16 @@ transferNonConstMemberOperatorCall(const CXXOperatorCallExpr *Expr,
849850
handleNonConstMemberCall(Expr, RecordLoc, Result, State);
850851
}
851852

853+
static RecordStorageLocation *
854+
getSmartPtrLikeStorageLocation(const Expr &E, const Environment &Env) {
855+
if (!E.isPRValue())
856+
return dyn_cast_or_null<RecordStorageLocation>(Env.getStorageLocation(E));
857+
if (auto *PointerVal = dyn_cast_or_null<PointerValue>(Env.getValue(E)))
858+
return dyn_cast_or_null<RecordStorageLocation>(
859+
&PointerVal->getPointeeLoc());
860+
return nullptr;
861+
}
862+
852863
CFGMatchSwitch<LatticeTransferState>
853864
buildTransferMatchSwitch(ASTContext &Ctx,
854865
CFGMatchSwitchBuilder<LatticeTransferState> Builder) {
@@ -906,6 +917,43 @@ buildTransferMatchSwitch(ASTContext &Ctx,
906917
transferLoggingGetReferenceableValueCall)
907918
.CaseOfCFGStmt<CallExpr>(isLoggingCheckEqImpl(),
908919
transferLoggingCheckEqImpl)
920+
// This needs to go before the const accessor call matcher, because these
921+
// look like them, but we model `operator`* and `get` to return the same
922+
// object. Also, we model them for non-const cases.
923+
.CaseOfCFGStmt<CXXOperatorCallExpr>(
924+
isPointerLikeOperatorStar(),
925+
[](const CXXOperatorCallExpr *E,
926+
const MatchFinder::MatchResult &Result,
927+
LatticeTransferState &State) {
928+
transferSmartPointerLikeCachedDeref(
929+
E, getSmartPtrLikeStorageLocation(*E->getArg(0), State.Env),
930+
State, [](StorageLocation &Loc) {});
931+
})
932+
.CaseOfCFGStmt<CXXOperatorCallExpr>(
933+
isPointerLikeOperatorArrow(),
934+
[](const CXXOperatorCallExpr *E,
935+
const MatchFinder::MatchResult &Result,
936+
LatticeTransferState &State) {
937+
transferSmartPointerLikeCachedGet(
938+
E, getSmartPtrLikeStorageLocation(*E->getArg(0), State.Env),
939+
State, [](StorageLocation &Loc) {});
940+
})
941+
.CaseOfCFGStmt<CXXMemberCallExpr>(
942+
isSmartPointerLikeValueMethodCall(),
943+
[](const CXXMemberCallExpr *E, const MatchFinder::MatchResult &Result,
944+
LatticeTransferState &State) {
945+
transferSmartPointerLikeCachedDeref(
946+
E, getImplicitObjectLocation(*E, State.Env), State,
947+
[](StorageLocation &Loc) {});
948+
})
949+
.CaseOfCFGStmt<CXXMemberCallExpr>(
950+
isSmartPointerLikeGetMethodCall(),
951+
[](const CXXMemberCallExpr *E, const MatchFinder::MatchResult &Result,
952+
LatticeTransferState &State) {
953+
transferSmartPointerLikeCachedGet(
954+
E, getImplicitObjectLocation(*E, State.Env), State,
955+
[](StorageLocation &Loc) {});
956+
})
909957
// const accessor calls
910958
.CaseOfCFGStmt<CXXMemberCallExpr>(isConstStatusOrAccessorMemberCall(),
911959
transferConstStatusOrAccessorMemberCall)

clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3443,6 +3443,79 @@ TEST_P(UncheckedStatusOrAccessModelTest, AccessorCall) {
34433443
)cc");
34443444
}
34453445

3446+
TEST_P(UncheckedStatusOrAccessModelTest, PointerLike) {
3447+
ExpectDiagnosticsFor(R"cc(
3448+
#include "unchecked_statusor_access_test_defs.h"
3449+
3450+
class Foo {
3451+
public:
3452+
std::pair<int, STATUSOR_VOIDPTR>& operator*() const;
3453+
std::pair<int, STATUSOR_VOIDPTR>* operator->() const;
3454+
bool operator!=(const Foo& other) const;
3455+
};
3456+
3457+
void target() {
3458+
Foo foo;
3459+
if (foo->second.ok() && *foo->second != nullptr) {
3460+
*foo->second;
3461+
(*foo).second.value();
3462+
}
3463+
}
3464+
)cc");
3465+
ExpectDiagnosticsFor(R"cc(
3466+
#include "unchecked_statusor_access_test_defs.h"
3467+
3468+
class Foo {
3469+
public:
3470+
std::pair<int, STATUSOR_INT>& operator*() const;
3471+
std::pair<int, STATUSOR_INT>* operator->() const;
3472+
};
3473+
void target() {
3474+
Foo foo;
3475+
if (!foo->second.ok()) return;
3476+
foo->second.value();
3477+
(*foo).second.value();
3478+
}
3479+
)cc");
3480+
ExpectDiagnosticsFor(R"cc(
3481+
#include "unchecked_statusor_access_test_defs.h"
3482+
3483+
void target(std::pair<int, STATUSOR_VOIDPTR>* foo) {
3484+
if (foo->second.ok() && *foo->second != nullptr) {
3485+
*foo->second;
3486+
(*foo).second.value();
3487+
}
3488+
}
3489+
)cc");
3490+
}
3491+
3492+
TEST_P(UncheckedStatusOrAccessModelTest, UniquePtr) {
3493+
ExpectDiagnosticsFor(
3494+
R"cc(
3495+
#include "unchecked_statusor_access_test_defs.h"
3496+
3497+
void target() {
3498+
auto sor_up = Make<std::unique_ptr<STATUSOR_INT>>();
3499+
if (sor_up->ok()) sor_up->value();
3500+
}
3501+
)cc");
3502+
}
3503+
3504+
TEST_P(UncheckedStatusOrAccessModelTest, UniquePtrReset) {
3505+
ExpectDiagnosticsFor(
3506+
R"cc(
3507+
#include "unchecked_statusor_access_test_defs.h"
3508+
3509+
void target() {
3510+
auto sor_up = Make<std::unique_ptr<STATUSOR_INT>>();
3511+
if (sor_up->ok()) {
3512+
sor_up.reset(Make<STATUSOR_INT*>());
3513+
sor_up->value(); // [[unsafe]]
3514+
}
3515+
}
3516+
)cc");
3517+
}
3518+
34463519
} // namespace
34473520

34483521
std::string
@@ -3492,6 +3565,7 @@ GetHeaders(UncheckedStatusOrAccessModelTestAliasKind AliasKind) {
34923565
#include "std_pair.h"
34933566
#include "absl_log.h"
34943567
#include "testing_defs.h"
3568+
#include "std_unique_ptr.h"
34953569
34963570
template <typename T>
34973571
T Make();

mlir/include/mlir/IR/Interfaces.td

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ class StaticInterfaceMethod<string desc, string retTy, string methodName,
8585
: InterfaceMethod<desc, retTy, methodName, args, methodBody,
8686
defaultImplementation>;
8787

88+
// This class represents a pure virtual interface method.
89+
class PureVirtualInterfaceMethod<string desc, string retTy, string methodName,
90+
dag args = (ins)>
91+
: InterfaceMethod<desc, retTy, methodName, args>;
92+
93+
// This class represents a interface method declaration.
94+
class InterfaceMethodDeclaration<string desc, string retTy, string methodName,
95+
dag args = (ins)>
96+
: InterfaceMethod<desc, retTy, methodName, args>;
97+
8898
// Interface represents a base interface.
8999
class Interface<string name, list<Interface> baseInterfacesArg = []> {
90100
// A human-readable description of what this interface does.
@@ -147,9 +157,17 @@ class TypeInterface<string name, list<Interface> baseInterfaces = []>
147157
!if(!empty(cppNamespace),"", cppNamespace # "::") # name
148158
>;
149159

160+
// AliasDeclaration represents an Alias Declaration in a Dialect Interface
161+
class AliasDeclaration<string alias, string typeId> {
162+
string name = alias;
163+
string aliased = typeId;
164+
}
165+
150166
// DialectInterface represents a Dialect Interface.
151167
class DialectInterface<string name, list<Interface> baseInterfaces = []>
152-
: Interface<name, baseInterfaces>, OpInterfaceTrait<name>;
168+
: Interface<name, baseInterfaces>, OpInterfaceTrait<name> {
169+
list<AliasDeclaration> aliasDeclarations = [];
170+
}
153171

154172

155173
// Whether to declare the interface methods in the user entity's header. This

mlir/include/mlir/TableGen/Interfaces.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ class InterfaceMethod {
4545

4646
// Return if this method is static.
4747
bool isStatic() const;
48+
49+
// Return if the method is a pure virtual one.
50+
bool isPureVirtual() const;
51+
52+
// Return if the method is only a declaration.
53+
bool isDeclaration() const;
4854

4955
// Return the body for this method if it has one.
5056
std::optional<StringRef> getBody() const;
@@ -161,6 +167,9 @@ struct TypeInterface : public Interface {
161167
struct DialectInterface : public Interface {
162168
using Interface::Interface;
163169

170+
// Return alias declarations
171+
SmallVector<std::pair<StringRef, StringRef>> getAliasDeclarations() const;
172+
164173
static bool classof(const Interface *interface);
165174
};
166175

mlir/lib/TableGen/Interfaces.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "llvm/ADT/StringSet.h"
1212
#include "llvm/TableGen/Error.h"
1313
#include "llvm/TableGen/Record.h"
14+
#include <utility>
1415

1516
using namespace mlir;
1617
using namespace mlir::tblgen;
@@ -51,6 +52,16 @@ bool InterfaceMethod::isStatic() const {
5152
return def->isSubClassOf("StaticInterfaceMethod");
5253
}
5354

55+
// Return if the method is a pure virtual one.
56+
bool InterfaceMethod::isPureVirtual() const {
57+
return def->isSubClassOf("PureVirtualInterfaceMethod");
58+
}
59+
60+
// Return if the method is only a declaration.
61+
bool InterfaceMethod::isDeclaration() const {
62+
return def->isSubClassOf("InterfaceMethodDeclaration");
63+
}
64+
5465
// Return the body for this method if it has one.
5566
std::optional<StringRef> InterfaceMethod::getBody() const {
5667
// Trim leading and trailing spaces from the default implementation.
@@ -216,3 +227,16 @@ bool TypeInterface::classof(const Interface *interface) {
216227
bool DialectInterface::classof(const Interface *interface) {
217228
return interface->getDef().isSubClassOf("DialectInterface");
218229
}
230+
231+
// Return the interfaces extra class declaration code.
232+
SmallVector<std::pair<StringRef, StringRef>> DialectInterface::getAliasDeclarations() const {
233+
SmallVector<std::pair<StringRef, StringRef>, 1> aliasDeclarations;
234+
235+
for (auto &aliasDef : getDef().getValueAsListOfDefs("aliasDeclarations")) {
236+
auto alias = aliasDef->getValueAsString("name");
237+
auto typeId = aliasDef->getValueAsString("aliased");
238+
aliasDeclarations.push_back(std::make_pair(alias, typeId));
239+
}
240+
return aliasDeclarations;
241+
}
242+

mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
using namespace mlir;
2727
using llvm::Record;
2828
using llvm::RecordKeeper;
29-
using mlir::tblgen::Interface;
29+
using mlir::tblgen::DialectInterface;
3030
using mlir::tblgen::InterfaceMethod;
3131

3232
/// Emit a string corresponding to a C++ type, followed by a space if necessary.
@@ -74,7 +74,7 @@ class DialectInterfaceGenerator {
7474
bool emitInterfaceDecls();
7575

7676
protected:
77-
void emitInterfaceDecl(const Interface &interface);
77+
void emitInterfaceDecl(const DialectInterface &interface);
7878

7979
/// The set of interface records to emit.
8080
std::vector<const Record *> defs;
@@ -93,7 +93,7 @@ static void emitInterfaceMethodDoc(const InterfaceMethod &method,
9393
tblgen::emitDescriptionComment(*description, os, prefix);
9494
}
9595

96-
static void emitInterfaceMethodsDef(const Interface &interface,
96+
static void emitInterfaceMethodsDef(const DialectInterface &interface,
9797
raw_ostream &os) {
9898

9999
raw_indented_ostream ios(os);
@@ -104,6 +104,18 @@ static void emitInterfaceMethodsDef(const Interface &interface,
104104
ios << "virtual ";
105105
emitCPPType(method.getReturnType(), ios);
106106
emitMethodNameAndArgs(method, method.getName(), ios);
107+
108+
if (method.isDeclaration()) {
109+
ios << ";\n";
110+
continue;
111+
}
112+
113+
if (method.isPureVirtual()) {
114+
ios << " = 0;\n";
115+
continue;
116+
}
117+
118+
// Otherwise it's a normal interface method
107119
ios << " {";
108120

109121
if (auto body = method.getBody()) {
@@ -116,7 +128,17 @@ static void emitInterfaceMethodsDef(const Interface &interface,
116128
}
117129
}
118130

119-
void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
131+
static void emitInterfaceAliasDeclarations(const DialectInterface &interface, raw_ostream &os) {
132+
raw_indented_ostream ios(os);
133+
ios.indent(2);
134+
135+
for (auto [alias, typeId] : interface.getAliasDeclarations()) {
136+
ios << "using " << alias << " = " << typeId << ";\n";
137+
}
138+
139+
}
140+
141+
void DialectInterfaceGenerator::emitInterfaceDecl(const DialectInterface &interface) {
120142
llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
121143

122144
StringRef interfaceName = interface.getName();
@@ -131,6 +153,8 @@ void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
131153
" {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
132154
interfaceName);
133155

156+
emitInterfaceAliasDeclarations(interface, os);
157+
134158
emitInterfaceMethodsDef(interface, os);
135159

136160
os << "};\n";
@@ -148,7 +172,7 @@ bool DialectInterfaceGenerator::emitInterfaceDecls() {
148172
});
149173

150174
for (const Record *def : sortedDefs)
151-
emitInterfaceDecl(Interface(def));
175+
emitInterfaceDecl(DialectInterface(def));
152176

153177
return false;
154178
}

0 commit comments

Comments
 (0)