Skip to content

Commit a22916b

Browse files
committed
[clang][nfc] Define ConstRecursiveASTVisitor twin of RecursiveASTVisitor
Downstream whenever we reach out for a RecursiveASTVisitor we always have to add a few const_casts to shoe it in. This NFC patch introduces a const version of the same CRTP class. To reduce code duplication, I factored out all the common logic (which is all of it) into `RecursiveASTVisitorBase` and made `RecursiveASTVisitor` and `ConstRecursiveASTVisitor` essentially the two instances of it, that you should use depending on whether you want to modify AST in your visitor. This is very similar to the DynamicRecursiveASTVisitor structure. One point of difference is that instead of type aliases I use inheritance to reduce the diff of this change because templated alias is not accepted in the implementation forwarding of the overridden member functions: `return RecursiveASTVisitor::TraverseStmt(S);` works only if `RecursiveASTVisitor` is defined as a derived class of `RecursiveASTVisitorBase` and not as a parametric alias. This was not an issue for DynamicRecursiveASTVisitor because it is not parametrised bythe `Derived` type. Unfortunately, I did not manager to maintain a full backwards compatibility when it comes to the `friend` declarations, you have to befriend the `RecursiveASTVisitorBase` and not `RecursiveASTVisitor`. Moreover, the error message is not obvious, as it speaks of the member function being private and does not point to the `friend` declaration.
1 parent 3cc1b7c commit a22916b

File tree

8 files changed

+677
-511
lines changed

8 files changed

+677
-511
lines changed

clang-tools-extra/clang-tidy/modernize/LoopConvertUtils.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class StmtAncestorASTVisitor
7272
/// Accessor for DeclParents.
7373
const DeclParentMap &getDeclToParentStmtMap() { return DeclParents; }
7474

75-
friend class clang::RecursiveASTVisitor<StmtAncestorASTVisitor>;
75+
friend class clang::RecursiveASTVisitorBase<StmtAncestorASTVisitor, /*Const=*/false>;
7676

7777
private:
7878
StmtParentMap StmtAncestors;
@@ -98,7 +98,7 @@ class ComponentFinderASTVisitor
9898
/// Accessor for Components.
9999
const ComponentVector &getComponents() { return Components; }
100100

101-
friend class clang::RecursiveASTVisitor<ComponentFinderASTVisitor>;
101+
friend class clang::RecursiveASTVisitorBase<ComponentFinderASTVisitor, /*Const=*/false>;
102102

103103
private:
104104
ComponentVector Components;
@@ -155,7 +155,7 @@ class DependencyFinderASTVisitor
155155
return DependsOnInsideVariable;
156156
}
157157

158-
friend class clang::RecursiveASTVisitor<DependencyFinderASTVisitor>;
158+
friend class clang::RecursiveASTVisitorBase<DependencyFinderASTVisitor, /*Const=*/false>;
159159

160160
private:
161161
const StmtParentMap *StmtParents;
@@ -188,7 +188,7 @@ class DeclFinderASTVisitor
188188
return Found;
189189
}
190190

191-
friend class clang::RecursiveASTVisitor<DeclFinderASTVisitor>;
191+
friend class clang::RecursiveASTVisitorBase<DeclFinderASTVisitor, /*Const=*/false>;
192192

193193
private:
194194
std::string Name;
@@ -340,7 +340,7 @@ class ForLoopIndexUseVisitor
340340
private:
341341
/// Typedef used in CRTP functions.
342342
using VisitorBase = RecursiveASTVisitor<ForLoopIndexUseVisitor>;
343-
friend class RecursiveASTVisitor<ForLoopIndexUseVisitor>;
343+
friend class RecursiveASTVisitorBase<ForLoopIndexUseVisitor, /*Const=*/false>;
344344

345345
/// Overriden methods for RecursiveASTVisitor's traversal.
346346
bool TraverseArraySubscriptExpr(ArraySubscriptExpr *E);

clang-tools-extra/clang-tidy/modernize/PassByValueCheck.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ static bool paramReferredExactlyOnce(const CXXConstructorDecl *Ctor,
9595
/// \see ExactlyOneUsageVisitor::hasExactlyOneUsageIn()
9696
class ExactlyOneUsageVisitor
9797
: public RecursiveASTVisitor<ExactlyOneUsageVisitor> {
98-
friend class RecursiveASTVisitor<ExactlyOneUsageVisitor>;
98+
friend class RecursiveASTVisitorBase<ExactlyOneUsageVisitor, /*Const=*/false>;
9999

100100
public:
101101
ExactlyOneUsageVisitor(const ParmVarDecl *ParamDecl)

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 528 additions & 492 deletions
Large diffs are not rendered by default.

clang/include/clang/AST/StmtOpenACC.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ class OpenACCConstructStmt : public Stmt {
8181
class OpenACCAssociatedStmtConstruct : public OpenACCConstructStmt {
8282
friend class ASTStmtWriter;
8383
friend class ASTStmtReader;
84-
template <typename Derived> friend class RecursiveASTVisitor;
8584
Stmt *AssociatedStmt = nullptr;
8685

8786
protected:

clang/lib/AST/ParentMapContext.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ class ParentMapContext::ParentMap::ASTVisitor
363363
ASTVisitor(ParentMap &Map) : Map(Map) {}
364364

365365
private:
366-
friend class RecursiveASTVisitor<ASTVisitor>;
366+
friend class RecursiveASTVisitorBase<ASTVisitor, false>;
367367

368368
using VisitorBase = RecursiveASTVisitor<ASTVisitor>;
369369

clang/lib/Index/IndexBody.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ class BodyIndexer : public RecursiveASTVisitor<BodyIndexer> {
508508
bool TraverseTypeConstraint(const TypeConstraint *C) {
509509
IndexCtx.handleReference(C->getNamedConcept(), C->getConceptNameLoc(),
510510
Parent, ParentDC);
511-
return RecursiveASTVisitor::TraverseTypeConstraint(C);
511+
return base::TraverseTypeConstraint(C);
512512
}
513513
};
514514

clang/unittests/AST/RecursiveASTVisitorTest.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,82 @@ std::vector<VisitEvent> collectEvents(llvm::StringRef Code,
143143
Code, FileName);
144144
return std::move(Visitor).takeEvents();
145145
}
146+
class ConstCollectInterestingEvents
147+
: public ConstRecursiveASTVisitor<ConstCollectInterestingEvents> {
148+
public:
149+
bool TraverseFunctionDecl(const FunctionDecl *D) {
150+
Events.push_back(VisitEvent::StartTraverseFunction);
151+
bool Ret = ConstRecursiveASTVisitor::TraverseFunctionDecl(D);
152+
Events.push_back(VisitEvent::EndTraverseFunction);
153+
154+
return Ret;
155+
}
156+
157+
bool TraverseAttr(const Attr *A) {
158+
Events.push_back(VisitEvent::StartTraverseAttr);
159+
bool Ret = ConstRecursiveASTVisitor::TraverseAttr(A);
160+
Events.push_back(VisitEvent::EndTraverseAttr);
161+
162+
return Ret;
163+
}
164+
165+
bool TraverseEnumDecl(const EnumDecl *D) {
166+
Events.push_back(VisitEvent::StartTraverseEnum);
167+
bool Ret = ConstRecursiveASTVisitor::TraverseEnumDecl(D);
168+
Events.push_back(VisitEvent::EndTraverseEnum);
169+
170+
return Ret;
171+
}
172+
173+
bool TraverseTypedefTypeLoc(TypedefTypeLoc TL, bool TraverseQualifier) {
174+
Events.push_back(VisitEvent::StartTraverseTypedefType);
175+
bool Ret =
176+
ConstRecursiveASTVisitor::TraverseTypedefTypeLoc(TL, TraverseQualifier);
177+
Events.push_back(VisitEvent::EndTraverseTypedefType);
178+
179+
return Ret;
180+
}
181+
182+
bool TraverseObjCInterfaceDecl(const ObjCInterfaceDecl *ID) {
183+
Events.push_back(VisitEvent::StartTraverseObjCInterface);
184+
bool Ret = ConstRecursiveASTVisitor::TraverseObjCInterfaceDecl(ID);
185+
Events.push_back(VisitEvent::EndTraverseObjCInterface);
186+
187+
return Ret;
188+
}
189+
190+
bool TraverseObjCProtocolDecl(const ObjCProtocolDecl *PD) {
191+
Events.push_back(VisitEvent::StartTraverseObjCProtocol);
192+
bool Ret = ConstRecursiveASTVisitor::TraverseObjCProtocolDecl(PD);
193+
Events.push_back(VisitEvent::EndTraverseObjCProtocol);
194+
195+
return Ret;
196+
}
197+
198+
bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc) {
199+
Events.push_back(VisitEvent::StartTraverseObjCProtocolLoc);
200+
bool Ret = ConstRecursiveASTVisitor::TraverseObjCProtocolLoc(ProtocolLoc);
201+
Events.push_back(VisitEvent::EndTraverseObjCProtocolLoc);
202+
203+
return Ret;
204+
}
205+
206+
std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
207+
208+
private:
209+
std::vector<VisitEvent> Events;
210+
};
211+
212+
std::vector<VisitEvent> collectConstEvents(llvm::StringRef Code,
213+
const Twine &FileName = "input.cc") {
214+
ConstCollectInterestingEvents Visitor;
215+
clang::tooling::runToolOnCode(
216+
std::make_unique<ProcessASTAction>(
217+
[&](const clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
218+
Code, FileName);
219+
return std::move(Visitor).takeEvents();
220+
}
221+
146222
} // namespace
147223

148224
TEST(RecursiveASTVisitorTest, AttributesInsideDecls) {
@@ -151,6 +227,7 @@ TEST(RecursiveASTVisitorTest, AttributesInsideDecls) {
151227
__attribute__((annotate("something"))) int foo() { return 10; }
152228
)cpp";
153229

230+
EXPECT_EQ(collectEvents(Code), collectConstEvents(Code));
154231
EXPECT_THAT(collectEvents(Code),
155232
ElementsAre(VisitEvent::StartTraverseFunction,
156233
VisitEvent::StartTraverseAttr,
@@ -165,6 +242,7 @@ TEST(RecursiveASTVisitorTest, EnumDeclWithBase) {
165242
enum Bar : Foo;
166243
)cpp";
167244

245+
EXPECT_EQ(collectEvents(Code), collectConstEvents(Code));
168246
EXPECT_THAT(collectEvents(Code),
169247
ElementsAre(VisitEvent::StartTraverseEnum,
170248
VisitEvent::StartTraverseTypedefType,
@@ -184,6 +262,7 @@ TEST(RecursiveASTVisitorTest, InterfaceDeclWithProtocols) {
184262
@end
185263
)cpp";
186264

265+
EXPECT_EQ(collectEvents(Code), collectConstEvents(Code));
187266
EXPECT_THAT(collectEvents(Code, "input.m"),
188267
ElementsAre(VisitEvent::StartTraverseObjCProtocol,
189268
VisitEvent::EndTraverseObjCProtocol,
@@ -196,3 +275,53 @@ TEST(RecursiveASTVisitorTest, InterfaceDeclWithProtocols) {
196275
VisitEvent::EndTraverseObjCProtocolLoc,
197276
VisitEvent::EndTraverseObjCInterface));
198277
}
278+
279+
TEST(ConstRecursiveASTVisitorTest, ConstCorrectness) {
280+
// This test verifies that ConstRecursiveASTVisitor properly enforces
281+
// const-correctness.
282+
// The derived class defines const versions of the Visit* methods,
283+
// and they should correctly override the default implementations,
284+
// which is demonstrated by non-0 counters.
285+
286+
class ConstCorrectnessValidator
287+
: public ConstRecursiveASTVisitor<ConstCorrectnessValidator> {
288+
public:
289+
bool VisitFunctionDecl(const FunctionDecl *D) {
290+
FunctionDeclCount++;
291+
return true;
292+
}
293+
294+
bool VisitStmt(const Stmt *S) {
295+
StmtCount++;
296+
return true;
297+
}
298+
299+
int getFunctionDeclCount() const { return FunctionDeclCount; }
300+
int getStmtCount() const { return StmtCount; }
301+
302+
private:
303+
int FunctionDeclCount = 0;
304+
int StmtCount = 0;
305+
};
306+
307+
llvm::StringRef Code = R"cpp(
308+
int foo() {
309+
return 42;
310+
}
311+
void bar() {
312+
int x = 0;
313+
x += 2;
314+
}
315+
)cpp";
316+
317+
ConstCorrectnessValidator Visitor;
318+
clang::tooling::runToolOnCode(
319+
std::make_unique<ProcessASTAction>(
320+
[&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
321+
Code);
322+
323+
// Verify that the visitor found the expected number of nodes
324+
EXPECT_EQ(Visitor.getFunctionDeclCount(), 2); // foo and bar
325+
// There are at least 3 statements: return 42; int x = 0; x += 2;
326+
EXPECT_GE(Visitor.getStmtCount(), 3);
327+
}

clang/utils/TableGen/ClangAttrEmitter.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3983,10 +3983,10 @@ void EmitClangAttrASTVisitor(const RecordKeeper &Records, raw_ostream &OS) {
39833983
const Record &R = *Attr;
39843984
if (!R.getValueAsBit("ASTNode"))
39853985
continue;
3986-
OS << " bool Traverse"
3987-
<< R.getName() << "Attr(" << R.getName() << "Attr *A);\n";
3988-
OS << " bool Visit"
3989-
<< R.getName() << "Attr(" << R.getName() << "Attr *A) {\n"
3986+
OS << " bool Traverse" << R.getName()
3987+
<< "Attr(MaybeConst<" << R.getName() << "Attr> *A);\n";
3988+
OS << " bool Visit" << R.getName()
3989+
<< "Attr(MaybeConst<" << R.getName() << "Attr> *A) {\n"
39903990
<< " return true; \n"
39913991
<< " }\n";
39923992
}
@@ -3998,9 +3998,9 @@ void EmitClangAttrASTVisitor(const RecordKeeper &Records, raw_ostream &OS) {
39983998
if (!R.getValueAsBit("ASTNode"))
39993999
continue;
40004000

4001-
OS << "template <typename Derived>\n"
4002-
<< "bool VISITORCLASS<Derived>::Traverse"
4003-
<< R.getName() << "Attr(" << R.getName() << "Attr *A) {\n"
4001+
OS << "template <typename Derived, bool IsConst>\n"
4002+
<< "bool VISITORCLASS<Derived, IsConst>::Traverse" << R.getName()
4003+
<< "Attr(MaybeConst<" << R.getName() << "Attr> *A) {\n"
40044004
<< " if (!getDerived().VisitAttr(A))\n"
40054005
<< " return false;\n"
40064006
<< " if (!getDerived().Visit" << R.getName() << "Attr(A))\n"
@@ -4018,8 +4018,9 @@ void EmitClangAttrASTVisitor(const RecordKeeper &Records, raw_ostream &OS) {
40184018
}
40194019

40204020
// Write generic Traverse routine
4021-
OS << "template <typename Derived>\n"
4022-
<< "bool VISITORCLASS<Derived>::TraverseAttr(Attr *A) {\n"
4021+
OS << "template <typename Derived, bool IsConst>\n"
4022+
<< "bool VISITORCLASS<Derived, IsConst>::TraverseAttr("
4023+
<< "MaybeConst<Attr> *A) {\n"
40234024
<< " if (!A)\n"
40244025
<< " return true;\n"
40254026
<< "\n"
@@ -4032,7 +4033,8 @@ void EmitClangAttrASTVisitor(const RecordKeeper &Records, raw_ostream &OS) {
40324033

40334034
OS << " case attr::" << R.getName() << ":\n"
40344035
<< " return getDerived().Traverse" << R.getName() << "Attr("
4035-
<< "cast<" << R.getName() << "Attr>(A));\n";
4036+
<< "const_cast<MaybeConst<" << R.getName()
4037+
<< "Attr> *>(static_cast<const " << R.getName() << "Attr *>(A)));\n";
40364038
}
40374039
OS << " }\n"; // end switch
40384040
OS << " llvm_unreachable(\"bad attribute kind\");\n";

0 commit comments

Comments
 (0)