diff --git a/clang/include/clang/ASTMatchers/ASTMatchFinder.h b/clang/include/clang/ASTMatchers/ASTMatchFinder.h index a387d9037b7da..2b161a574d5b6 100644 --- a/clang/include/clang/ASTMatchers/ASTMatchFinder.h +++ b/clang/include/clang/ASTMatchers/ASTMatchFinder.h @@ -139,6 +139,13 @@ class MatchFinder { /// /// It prints a report after match. std::optional CheckProfiling; + + /// Whether to traverse a Decl. This is relevant for clang modules, as they + /// are imported into the AST, but are actually part of a different TU. + /// It can result in hundreds of milliseconds of additional time to also + /// traverse the AST of these modules, and often for no benefit, as they + /// are frequently already traversed in their own TU. + std::optional> ShouldTraverseDecl; }; MatchFinder(MatchFinderOptions Options = MatchFinderOptions()); diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp index 3d01a70395a9b..5d2f2065ceba1 100644 --- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp +++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp @@ -1443,7 +1443,8 @@ bool MatchASTVisitor::objcClassIsDerivedFrom( } bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) { - if (!DeclNode) { + if (!DeclNode || (Options.ShouldTraverseDecl && + !(*Options.ShouldTraverseDecl)(*DeclNode))) { return true; } diff --git a/clang/unittests/ASTMatchers/ASTMatchersTest.h b/clang/unittests/ASTMatchers/ASTMatchersTest.h index ad2f5f355621c..02bdcc3a3ab1f 100644 --- a/clang/unittests/ASTMatchers/ASTMatchersTest.h +++ b/clang/unittests/ASTMatchers/ASTMatchersTest.h @@ -59,6 +59,11 @@ class VerifyMatch : public MatchFinder::MatchCallback { const std::unique_ptr FindResultReviewer; }; +inline ArrayRef langCxx11() { + static const TestLanguage Result[] = {Lang_CXX11}; + return ArrayRef(Result); +} + inline ArrayRef langCxx11OrLater() { static const TestLanguage Result[] = {Lang_CXX11, Lang_CXX14, Lang_CXX17, Lang_CXX20, Lang_CXX23}; @@ -91,9 +96,11 @@ testing::AssertionResult matchesConditionally( const Twine &Code, const T &AMatcher, bool ExpectMatch, ArrayRef CompileArgs, const FileContentMappings &VirtualMappedFiles = FileContentMappings(), - StringRef Filename = "input.cc") { + StringRef Filename = "input.cc", + MatchFinder::MatchFinderOptions Options = + MatchFinder::MatchFinderOptions()) { bool Found = false, DynamicFound = false; - MatchFinder Finder; + MatchFinder Finder(Options); VerifyMatch VerifyFound(nullptr, &Found); Finder.addMatcher(AMatcher, &VerifyFound); VerifyMatch VerifyDynamicFound(nullptr, &DynamicFound); @@ -147,11 +154,13 @@ testing::AssertionResult matchesConditionally( template testing::AssertionResult matchesConditionally(const Twine &Code, const T &AMatcher, bool ExpectMatch, - ArrayRef TestLanguages) { + ArrayRef TestLanguages, + MatchFinder::MatchFinderOptions Options = + MatchFinder::MatchFinderOptions()) { for (auto Lang : TestLanguages) { auto Result = matchesConditionally( Code, AMatcher, ExpectMatch, getCommandLineArgsForTesting(Lang), - FileContentMappings(), getFilenameForTesting(Lang)); + FileContentMappings(), getFilenameForTesting(Lang), Options); if (!Result) return Result; } @@ -162,8 +171,10 @@ matchesConditionally(const Twine &Code, const T &AMatcher, bool ExpectMatch, template testing::AssertionResult matches(const Twine &Code, const T &AMatcher, - ArrayRef TestLanguages = {Lang_CXX11}) { - return matchesConditionally(Code, AMatcher, true, TestLanguages); + ArrayRef TestLanguages = {Lang_CXX11}, + MatchFinder::MatchFinderOptions Options = + MatchFinder::MatchFinderOptions()) { + return matchesConditionally(Code, AMatcher, true, TestLanguages, Options); } template diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp index 068cf66771027..02badc50241d2 100644 --- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp @@ -28,6 +28,18 @@ TEST(DeclarationMatcher, hasMethod) { cxxRecordDecl(hasMethod(isPublic())))); } +TEST(DeclarationMatcher, shouldTraverse) { + MatchFinder::MatchFinderOptions Options; + Options.ShouldTraverseDecl = [](const Decl &decl) { return true; }; + EXPECT_TRUE(matches("class A { void func(); };", + cxxRecordDecl(hasMethod(hasName("func"))), langCxx11(), + Options)); + Options.ShouldTraverseDecl = [](const Decl &decl) { return false; }; + EXPECT_FALSE(matches("class A { void func(); };", + cxxRecordDecl(hasMethod(hasName("func"))), langCxx11(), + Options)); +} + TEST(DeclarationMatcher, ClassDerivedFromDependentTemplateSpecialization) { EXPECT_TRUE(matches( "template struct A {"