diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h index 2954a1c4769f7..0f851830edd46 100644 --- a/flang/include/flang/Semantics/openmp-utils.h +++ b/flang/include/flang/Semantics/openmp-utils.h @@ -38,6 +38,7 @@ template > U AsRvalue(T &t) { template T &&AsRvalue(T &&t) { return std::move(t); } const Scope &GetScopingUnit(const Scope &scope); +const Scope &GetProgramUnit(const Scope &scope); // There is no consistent way to get the source of an ActionStmt, but there // is "source" in Statement. This structure keeps the ActionStmt with the diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp index a8ec4d6c24beb..292e73b4899c0 100644 --- a/flang/lib/Semantics/openmp-utils.cpp +++ b/flang/lib/Semantics/openmp-utils.cpp @@ -13,6 +13,7 @@ #include "flang/Semantics/openmp-utils.h" #include "flang/Common/Fortran-consts.h" +#include "flang/Common/idioms.h" #include "flang/Common/indirection.h" #include "flang/Common/reference.h" #include "flang/Common/visit.h" @@ -59,6 +60,26 @@ const Scope &GetScopingUnit(const Scope &scope) { return *iter; } +const Scope &GetProgramUnit(const Scope &scope) { + const Scope *unit{nullptr}; + for (const Scope *iter{&scope}; !iter->IsTopLevel(); iter = &iter->parent()) { + switch (iter->kind()) { + case Scope::Kind::BlockData: + case Scope::Kind::MainProgram: + case Scope::Kind::Module: + return *iter; + case Scope::Kind::Subprogram: + // Ignore subprograms that are nested. + unit = iter; + break; + default: + break; + } + } + assert(unit && "Scope not in a program unit"); + return *unit; +} + SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x) { if (x == nullptr) { return SourcedActionStmt{}; @@ -202,7 +223,7 @@ std::optional GetEvaluateExpr(const parser::Expr &parserExpr) { // ForwardOwningPointer typedExpr // `- GenericExprWrapper ^.get() // `- std::optional ^->v - return typedExpr.get()->v; + return DEREF(typedExpr.get()).v; } std::optional GetDynamicType( diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 18fc63814d973..de680b41d1524 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -3549,40 +3549,38 @@ void OmpAttributeVisitor::CheckLabelContext(const parser::CharBlock source, void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope, WithOmpDeclarative::RequiresFlags flags, std::optional memOrder) { - Scope *scopeIter = &scope; - do { - if (Symbol * symbol{scopeIter->symbol()}) { - common::visit( - [&](auto &details) { - // Store clauses information into the symbol for the parent and - // enclosing modules, programs, functions and subroutines. - if constexpr (std::is_convertible_v) { - if (flags.any()) { - if (const WithOmpDeclarative::RequiresFlags * - otherFlags{details.ompRequires()}) { - flags |= *otherFlags; - } - details.set_ompRequires(flags); + const Scope &programUnit{omp::GetProgramUnit(scope)}; + + if (auto *symbol{const_cast(programUnit.symbol())}) { + common::visit( + [&](auto &details) { + // Store clauses information into the symbol for the parent and + // enclosing modules, programs, functions and subroutines. + if constexpr (std::is_convertible_v) { + if (flags.any()) { + if (const WithOmpDeclarative::RequiresFlags *otherFlags{ + details.ompRequires()}) { + flags |= *otherFlags; } - if (memOrder) { - if (details.has_ompAtomicDefaultMemOrder() && - *details.ompAtomicDefaultMemOrder() != *memOrder) { - context_.Say(scopeIter->sourceRange(), - "Conflicting '%s' REQUIRES clauses found in compilation " - "unit"_err_en_US, - parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( - llvm::omp::Clause::OMPC_atomic_default_mem_order) - .str())); - } - details.set_ompAtomicDefaultMemOrder(*memOrder); + details.set_ompRequires(flags); + } + if (memOrder) { + if (details.has_ompAtomicDefaultMemOrder() && + *details.ompAtomicDefaultMemOrder() != *memOrder) { + context_.Say(programUnit.sourceRange(), + "Conflicting '%s' REQUIRES clauses found in compilation " + "unit"_err_en_US, + parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( + llvm::omp::Clause::OMPC_atomic_default_mem_order) + .str())); } + details.set_ompAtomicDefaultMemOrder(*memOrder); } - }, - symbol->details()); - } - scopeIter = &scopeIter->parent(); - } while (!scopeIter->IsGlobal()); + } + }, + symbol->details()); + } } void OmpAttributeVisitor::IssueNonConformanceWarning(llvm::omp::Directive D,