diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h index f1d88a9523838..f007ef3bdf88d 100644 --- a/clang/include/clang/ASTMatchers/ASTMatchers.h +++ b/clang/include/clang/ASTMatchers/ASTMatchers.h @@ -8723,6 +8723,21 @@ AST_MATCHER_P(OMPExecutableDirective, hasAnyClause, Builder) != Clauses.end(); } +/// Matches any ``#pragma omp target update`` executable directive. +/// +/// Given +/// +/// \code +/// #pragma omp target update from(a) +/// #pragma omp target update to(b) +/// \endcode +/// +/// ``ompTargetUpdateDirective()`` matches both ``omp target update from(a)`` +/// and ``omp target update to(b)``. +extern const internal::VariadicDynCastAllOfMatcher + ompTargetUpdateDirective; + /// Matches OpenMP ``default`` clause. /// /// Given @@ -8836,6 +8851,30 @@ AST_MATCHER_P(OMPExecutableDirective, isAllowedToContainClauseKind, Finder->getASTContext().getLangOpts().OpenMP); } +/// Matches OpenMP ``from`` clause. +/// +/// Given +/// +/// \code +/// #pragma omp target update from(a) +/// \endcode +/// +/// ``ompFromClause()`` matches ``from(a)``. +extern const internal::VariadicDynCastAllOfMatcher + ompFromClause; + +/// Matches OpenMP ``to`` clause. +/// +/// Given +/// +/// \code +/// #pragma omp target update to(a) +/// \endcode +/// +/// ``ompToClause()`` matches ``to(a)``. +extern const internal::VariadicDynCastAllOfMatcher + ompToClause; + //----------------------------------------------------------------------------// // End OpenMP handling. //----------------------------------------------------------------------------// diff --git a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp index 653b3810cb68b..5efa7f162789c 100644 --- a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp +++ b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp @@ -1124,8 +1124,13 @@ AST_TYPELOC_TRAVERSE_MATCHER_DEF( const internal::VariadicDynCastAllOfMatcher ompExecutableDirective; +const internal::VariadicDynCastAllOfMatcher + ompTargetUpdateDirective; const internal::VariadicDynCastAllOfMatcher ompDefaultClause; +const internal::VariadicDynCastAllOfMatcher + ompFromClause; +const internal::VariadicDynCastAllOfMatcher ompToClause; const internal::VariadicDynCastAllOfMatcher cxxDeductionGuideDecl; diff --git a/clang/lib/ASTMatchers/Dynamic/Registry.cpp b/clang/lib/ASTMatchers/Dynamic/Registry.cpp index 48a7b91969aef..447c70dc6f9af 100644 --- a/clang/lib/ASTMatchers/Dynamic/Registry.cpp +++ b/clang/lib/ASTMatchers/Dynamic/Registry.cpp @@ -528,7 +528,10 @@ RegistryMaps::RegistryMaps() { REGISTER_MATCHER(ofClass); REGISTER_MATCHER(ofKind); REGISTER_MATCHER(ompDefaultClause); + REGISTER_MATCHER(ompFromClause); + REGISTER_MATCHER(ompToClause); REGISTER_MATCHER(ompExecutableDirective); + REGISTER_MATCHER(ompTargetUpdateDirective); REGISTER_MATCHER(on); REGISTER_MATCHER(onImplicitObjectArgument); REGISTER_MATCHER(opaqueValueExpr); diff --git a/clang/unittests/ASTMatchers/ASTMatchersNarrowingTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersNarrowingTest.cpp index 8a957864cdd12..63639cc890ec9 100644 --- a/clang/unittests/ASTMatchers/ASTMatchersNarrowingTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersNarrowingTest.cpp @@ -4734,6 +4734,203 @@ void x() { EXPECT_TRUE(matchesWithOpenMP(Source8, Matcher)); } +TEST_P(ASTMatchersTest, OMPTargetUpdateDirective_From_IsStandaloneDirective) { + auto Matcher = ompTargetUpdateDirective(isStandaloneDirective()); + + StringRef Source0 = R"( + void foo() { + int arr[8]; + #pragma omp target update from(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(matchesWithOpenMP(Source0, Matcher)); +} + +TEST_P(ASTMatchersTest, OMPTargetUpdateDirective_To_IsStandaloneDirective) { + auto Matcher = ompTargetUpdateDirective(isStandaloneDirective()); + + StringRef Source0 = R"( + void foo() { + int arr[8]; + #pragma omp target update to(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(matchesWithOpenMP(Source0, Matcher)); +} + +TEST_P(ASTMatchersTest, OMPTargetUpdateDirective_From_HasStructuredBlock) { + StringRef Source0 = R"( + void foo() { + int arr[8]; + #pragma omp target update from(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(notMatchesWithOpenMP( + Source0, ompTargetUpdateDirective(hasStructuredBlock(nullStmt())))); +} + +TEST_P(ASTMatchersTest, OMPTargetUpdateDirective_To_HasStructuredBlock) { + StringRef Source0 = R"( + void foo() { + int arr[8]; + #pragma omp target update to(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(notMatchesWithOpenMP( + Source0, ompTargetUpdateDirective(hasStructuredBlock(nullStmt())))); +} + +TEST_P(ASTMatchersTest, OMPTargetUpdateDirective_From_HasClause) { + auto Matcher = ompTargetUpdateDirective(hasAnyClause(ompFromClause())); + + StringRef Source0 = R"( + void foo() { + int arr[8]; + #pragma omp target update from(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(matchesWithOpenMP(Source0, Matcher)); + + auto astUnit = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"}); + ASSERT_TRUE(astUnit); + + auto Results = match(ompTargetUpdateDirective().bind("directive"), + astUnit->getASTContext()); + ASSERT_FALSE(Results.empty()); + + const auto *Directive = + Results[0].getNodeAs("directive"); + ASSERT_TRUE(Directive); + + OMPFromClause *FromClause = nullptr; + for (auto *Clause : Directive->clauses()) { + if ((FromClause = dyn_cast(Clause))) { + break; + } + } + ASSERT_TRUE(FromClause); + + for (const auto *VarExpr : FromClause->varlist()) { + const auto *ArraySection = dyn_cast(VarExpr); + if (!ArraySection) + continue; + // base (arr) + const Expr *Base = ArraySection->getBase(); + ASSERT_TRUE(Base); + + // lower bound (0) + const Expr *LowerBound = ArraySection->getLowerBound(); + ASSERT_TRUE(LowerBound); + + // length (8) + const Expr *Length = ArraySection->getLength(); + ASSERT_TRUE(Length); + + // stride (2) + const Expr *Stride = ArraySection->getStride(); + ASSERT_TRUE(Stride); + } +} + +TEST_P(ASTMatchersTest, OMPTargetUpdateDirective_To_HasClause) { + auto Matcher = ompTargetUpdateDirective(hasAnyClause(ompToClause())); + + StringRef Source0 = R"( + void foo() { + int arr[8]; + #pragma omp target update to(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(matchesWithOpenMP(Source0, Matcher)); + + auto astUnit = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"}); + ASSERT_TRUE(astUnit); + + auto Results = match(ompTargetUpdateDirective().bind("directive"), + astUnit->getASTContext()); + ASSERT_FALSE(Results.empty()); + + const auto *Directive = + Results[0].getNodeAs("directive"); + ASSERT_TRUE(Directive); + + OMPToClause *ToClause = nullptr; + for (auto *Clause : Directive->clauses()) { + if ((ToClause = dyn_cast(Clause))) { + break; + } + } + ASSERT_TRUE(ToClause); + + for (const auto *VarExpr : ToClause->varlist()) { + const auto *ArraySection = dyn_cast(VarExpr); + if (!ArraySection) + continue; + + const Expr *Base = ArraySection->getBase(); + ASSERT_TRUE(Base); + + const Expr *LowerBound = ArraySection->getLowerBound(); + ASSERT_TRUE(LowerBound); + + const Expr *Length = ArraySection->getLength(); + ASSERT_TRUE(Length); + + const Expr *Stride = ArraySection->getStride(); + ASSERT_TRUE(Stride); + } +} + +TEST_P(ASTMatchersTest, + OMPTargetUpdateDirective_IsAllowedToContainClauseKind_From) { + auto Matcher = ompTargetUpdateDirective( + isAllowedToContainClauseKind(llvm::omp::OMPC_from)); + + StringRef Source0 = R"( + void x() { + ; + } + )"; + EXPECT_TRUE(notMatchesWithOpenMP(Source0, Matcher)); + + StringRef Source1 = R"( + void foo() { + int arr[8]; + #pragma omp target update from(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(matchesWithOpenMP(Source1, Matcher)); +} + +TEST_P(ASTMatchersTest, + OMPTargetUpdateDirective_IsAllowedToContainClauseKind_To) { + auto Matcher = ompTargetUpdateDirective( + isAllowedToContainClauseKind(llvm::omp::OMPC_to)); + + StringRef Source0 = R"( + void x() { + ; + } + )"; + EXPECT_TRUE(notMatchesWithOpenMP(Source0, Matcher)); + + StringRef Source1 = R"( + void foo() { + int arr[8]; + #pragma omp target update to(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(matchesWithOpenMP(Source1, Matcher)); +} + TEST_P(ASTMatchersTest, HasAnyBase_DirectBase) { if (!GetParam().isCXX()) { return; diff --git a/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp index d7df9cae01f33..edc84704b5ed2 100644 --- a/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp @@ -2742,6 +2742,140 @@ void x() { EXPECT_TRUE(notMatchesWithOpenMP(Source2, Matcher)); } +TEST(ASTMatchersTestOpenMP, OMPTargetUpdateDirective_From) { + auto Matcher = stmt(ompTargetUpdateDirective()); + + StringRef Source0 = R"( + void foo() { + int arr[8]; + #pragma omp target update from(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(matchesWithOpenMP(Source0, Matcher)); +} + +TEST(ASTMatchersTestOpenMP, OMPTargetUpdateDirective_To) { + auto Matcher = stmt(ompTargetUpdateDirective()); + + StringRef Source0 = R"( + void foo() { + int arr[8]; + #pragma omp target update to(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(matchesWithOpenMP(Source0, Matcher)); +} + +TEST(ASTMatchersTestOpenMP, OMPFromClause) { + auto Matcher = ompTargetUpdateDirective(hasAnyClause(ompFromClause())); + + StringRef Source0 = R"( + void foo() { + int arr[8]; + #pragma omp target update from(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(matchesWithOpenMP(Source0, Matcher)); + + auto astUnit = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"}); + ASSERT_TRUE(astUnit); + + auto Results = match(ompTargetUpdateDirective().bind("directive"), + astUnit->getASTContext()); + ASSERT_FALSE(Results.empty()); + + const auto *Directive = + Results[0].getNodeAs("directive"); + ASSERT_TRUE(Directive); + + OMPFromClause *FromClause = nullptr; + for (auto *Clause : Directive->clauses()) { + if ((FromClause = dyn_cast(Clause))) { + break; + } + } + ASSERT_TRUE(FromClause); + + for (const auto *VarExpr : FromClause->varlist()) { + const auto *ArraySection = dyn_cast(VarExpr); + if (!ArraySection) + continue; + + // base (arr) + const Expr *Base = ArraySection->getBase(); + ASSERT_TRUE(Base); + + // lower bound (0) + const Expr *LowerBound = ArraySection->getLowerBound(); + ASSERT_TRUE(LowerBound); + + // length (8) + const Expr *Length = ArraySection->getLength(); + ASSERT_TRUE(Length); + + // stride (2) + const Expr *Stride = ArraySection->getStride(); + ASSERT_TRUE(Stride); + } +} + +TEST(ASTMatchersTestOpenMP, OMPToClause) { + auto Matcher = ompTargetUpdateDirective(hasAnyClause(ompToClause())); + + StringRef Source0 = R"( + void foo() { + int arr[8]; + #pragma omp target update to(arr[0:8:2]) + ; + } + )"; + EXPECT_TRUE(matchesWithOpenMP(Source0, Matcher)); + + auto astUnit = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"}); + ASSERT_TRUE(astUnit); + + auto Results = match(ompTargetUpdateDirective().bind("directive"), + astUnit->getASTContext()); + ASSERT_FALSE(Results.empty()); + + const auto *Directive = + Results[0].getNodeAs("directive"); + ASSERT_TRUE(Directive); + + OMPToClause *ToClause = nullptr; + for (auto *Clause : Directive->clauses()) { + if ((ToClause = dyn_cast(Clause))) { + break; + } + } + ASSERT_TRUE(ToClause); + + for (const auto *VarExpr : ToClause->varlist()) { + const auto *ArraySection = dyn_cast(VarExpr); + if (!ArraySection) + continue; + + // base (arr) + const Expr *Base = ArraySection->getBase(); + ASSERT_TRUE(Base); + + // lower bound (0) + const Expr *LowerBound = ArraySection->getLowerBound(); + ASSERT_TRUE(LowerBound); + + // length (8) + const Expr *Length = ArraySection->getLength(); + ASSERT_TRUE(Length); + + // stride (2) + const Expr *Stride = ArraySection->getStride(); + ASSERT_TRUE(Stride); + } +} + TEST(ASTMatchersTestOpenMP, OMPDefaultClause) { auto Matcher = ompExecutableDirective(hasAnyClause(ompDefaultClause()));