Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clang/include/clang/ASTMatchers/ASTMatchFinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ class MatchFinder {
///
/// It prints a report after match.
std::optional<Profiling> CheckProfiling;

/// Whether to traverse a Decl. This is relevant for clang modules, as they
/// are imported into the AST, but are actually part of a different TU.
/// It can result in hundreds of milliseconds of additional time to also
/// traverse the AST of these modules, and often for no benefit, as they
/// are frequently already traversed in their own TU.
std::optional<llvm::function_ref<bool(const Decl &)>> ShouldTraverseDecl;
};

MatchFinder(MatchFinderOptions Options = MatchFinderOptions());
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/ASTMatchers/ASTMatchFinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1443,7 +1443,8 @@ bool MatchASTVisitor::objcClassIsDerivedFrom(
}

bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) {
if (!DeclNode) {
if (!DeclNode || (Options.ShouldTraverseDecl &&
!(*Options.ShouldTraverseDecl)(*DeclNode))) {
return true;
}

Expand Down
23 changes: 17 additions & 6 deletions clang/unittests/ASTMatchers/ASTMatchersTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ class VerifyMatch : public MatchFinder::MatchCallback {
const std::unique_ptr<BoundNodesCallback> FindResultReviewer;
};

inline ArrayRef<TestLanguage> langCxx11() {
static const TestLanguage Result[] = {Lang_CXX11};
return ArrayRef<TestLanguage>(Result);
}

inline ArrayRef<TestLanguage> langCxx11OrLater() {
static const TestLanguage Result[] = {Lang_CXX11, Lang_CXX14, Lang_CXX17,
Lang_CXX20, Lang_CXX23};
Expand Down Expand Up @@ -91,9 +96,11 @@ testing::AssertionResult matchesConditionally(
const Twine &Code, const T &AMatcher, bool ExpectMatch,
ArrayRef<std::string> CompileArgs,
const FileContentMappings &VirtualMappedFiles = FileContentMappings(),
StringRef Filename = "input.cc") {
StringRef Filename = "input.cc",
MatchFinder::MatchFinderOptions Options =
MatchFinder::MatchFinderOptions()) {
bool Found = false, DynamicFound = false;
MatchFinder Finder;
MatchFinder Finder(Options);
VerifyMatch VerifyFound(nullptr, &Found);
Finder.addMatcher(AMatcher, &VerifyFound);
VerifyMatch VerifyDynamicFound(nullptr, &DynamicFound);
Expand Down Expand Up @@ -147,11 +154,13 @@ testing::AssertionResult matchesConditionally(
template <typename T>
testing::AssertionResult
matchesConditionally(const Twine &Code, const T &AMatcher, bool ExpectMatch,
ArrayRef<TestLanguage> TestLanguages) {
ArrayRef<TestLanguage> TestLanguages,
MatchFinder::MatchFinderOptions Options =
MatchFinder::MatchFinderOptions()) {
for (auto Lang : TestLanguages) {
auto Result = matchesConditionally(
Code, AMatcher, ExpectMatch, getCommandLineArgsForTesting(Lang),
FileContentMappings(), getFilenameForTesting(Lang));
FileContentMappings(), getFilenameForTesting(Lang), Options);
if (!Result)
return Result;
}
Expand All @@ -162,8 +171,10 @@ matchesConditionally(const Twine &Code, const T &AMatcher, bool ExpectMatch,
template <typename T>
testing::AssertionResult
matches(const Twine &Code, const T &AMatcher,
ArrayRef<TestLanguage> TestLanguages = {Lang_CXX11}) {
return matchesConditionally(Code, AMatcher, true, TestLanguages);
ArrayRef<TestLanguage> TestLanguages = {Lang_CXX11},
MatchFinder::MatchFinderOptions Options =
MatchFinder::MatchFinderOptions()) {
return matchesConditionally(Code, AMatcher, true, TestLanguages, Options);
}

template <typename T>
Expand Down
12 changes: 12 additions & 0 deletions clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ TEST(DeclarationMatcher, hasMethod) {
cxxRecordDecl(hasMethod(isPublic()))));
}

TEST(DeclarationMatcher, shouldTraverse) {
MatchFinder::MatchFinderOptions Options;
Options.ShouldTraverseDecl = [](const Decl &decl) { return true; };
EXPECT_TRUE(matches("class A { void func(); };",
cxxRecordDecl(hasMethod(hasName("func"))), langCxx11(),
Options));
Options.ShouldTraverseDecl = [](const Decl &decl) { return false; };
EXPECT_FALSE(matches("class A { void func(); };",
cxxRecordDecl(hasMethod(hasName("func"))), langCxx11(),
Options));
}

TEST(DeclarationMatcher, ClassDerivedFromDependentTemplateSpecialization) {
EXPECT_TRUE(matches(
"template <typename T> struct A {"
Expand Down