diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h new file mode 100644 index 0000000000000..579ea7d74957f --- /dev/null +++ b/flang/include/flang/Parser/openmp-utils.h @@ -0,0 +1,161 @@ +//===-- flang/Parser/openmp-utils.h ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Common OpenMP utilities. +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_PARSER_OPENMP_UTILS_H +#define FORTRAN_PARSER_OPENMP_UTILS_H + +#include "flang/Common/indirection.h" +#include "flang/Parser/parse-tree.h" +#include "llvm/Frontend/OpenMP/OMP.h" + +#include +#include +#include +#include +#include + +namespace Fortran::parser::omp { + +namespace detail { +using D = llvm::omp::Directive; + +template // +struct ConstructId { + static constexpr llvm::omp::Directive id{D::OMPD_unknown}; +}; + +#define MAKE_CONSTR_ID(Construct, Id) \ + template <> struct ConstructId { \ + static constexpr llvm::omp::Directive id{Id}; \ + } + +MAKE_CONSTR_ID(OmpAssumeDirective, D::OMPD_assume); +MAKE_CONSTR_ID(OmpCriticalDirective, D::OMPD_critical); +MAKE_CONSTR_ID(OmpDeclareVariantDirective, D::OMPD_declare_variant); +MAKE_CONSTR_ID(OmpErrorDirective, D::OMPD_error); +MAKE_CONSTR_ID(OmpMetadirectiveDirective, D::OMPD_metadirective); +MAKE_CONSTR_ID(OpenMPDeclarativeAllocate, D::OMPD_allocate); +MAKE_CONSTR_ID(OpenMPDeclarativeAssumes, D::OMPD_assumes); +MAKE_CONSTR_ID(OpenMPDeclareMapperConstruct, D::OMPD_declare_mapper); +MAKE_CONSTR_ID(OpenMPDeclareReductionConstruct, D::OMPD_declare_reduction); +MAKE_CONSTR_ID(OpenMPDeclareSimdConstruct, D::OMPD_declare_simd); +MAKE_CONSTR_ID(OpenMPDeclareTargetConstruct, D::OMPD_declare_target); +MAKE_CONSTR_ID(OpenMPExecutableAllocate, D::OMPD_allocate); +MAKE_CONSTR_ID(OpenMPRequiresConstruct, D::OMPD_requires); +MAKE_CONSTR_ID(OpenMPThreadprivate, D::OMPD_threadprivate); + +#undef MAKE_CONSTR_ID + +struct DirectiveNameScope { + static OmpDirectiveName MakeName(CharBlock source = {}, + llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) { + OmpDirectiveName name; + name.source = source; + name.v = id; + return name; + } + + static OmpDirectiveName GetOmpDirectiveName(const OmpNothingDirective &x) { + return MakeName(x.source, llvm::omp::Directive::OMPD_nothing); + } + + static OmpDirectiveName GetOmpDirectiveName(const OmpBeginBlockDirective &x) { + auto &dir{std::get(x.t)}; + return MakeName(dir.source, dir.v); + } + + static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) { + auto &dir{std::get(x.t)}; + return MakeName(dir.source, dir.v); + } + + static OmpDirectiveName GetOmpDirectiveName( + const OmpBeginSectionsDirective &x) { + auto &dir{std::get(x.t)}; + return MakeName(dir.source, dir.v); + } + + template + static OmpDirectiveName GetOmpDirectiveName(const T &x) { + if constexpr (WrapperTrait) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + return x.v.DirName(); + } else { + return GetOmpDirectiveName(x.v); + } + } else if constexpr (TupleTrait) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + return std::get(x.t).DirName(); + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + return MakeName(std::get(x.t).source, ConstructId::id); + } else { + return GetFromTuple( + x.t, std::make_index_sequence>{}); + } + } else if constexpr (UnionTrait) { + return common::visit( + [](auto &&s) { return GetOmpDirectiveName(s); }, x.u); + } else { + return MakeName(); + } + } + + template + static OmpDirectiveName GetFromTuple( + const std::tuple &t, std::index_sequence) { + OmpDirectiveName name = MakeName(); + auto accumulate = [&](const OmpDirectiveName &n) { + if (name.v == llvm::omp::Directive::OMPD_unknown) { + name = n; + } else { + assert( + n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names"); + } + }; + (accumulate(GetOmpDirectiveName(std::get(t))), ...); + return name; + } + + template + static OmpDirectiveName GetOmpDirectiveName(const common::Indirection &x) { + return GetOmpDirectiveName(x.value()); + } +}; +} // namespace detail + +template OmpDirectiveName GetOmpDirectiveName(const T &x) { + return detail::DirectiveNameScope::GetOmpDirectiveName(x); +} + +} // namespace Fortran::parser::omp + +#endif // FORTRAN_PARSER_OPENMP_UTILS_H diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp index 11e488371b886..2ac4d9548b65b 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp @@ -24,6 +24,7 @@ #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/HLFIR/HLFIRDialect.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Parser/openmp-utils.h" #include "flang/Semantics/attr.h" #include "flang/Semantics/tools.h" #include "llvm/ADT/Sequence.h" @@ -465,7 +466,8 @@ bool DataSharingProcessor::isOpenMPPrivatizingConstruct( // allow a privatizing clause) are: dispatch, distribute, do, for, loop, // parallel, scope, sections, simd, single, target, target_data, task, // taskgroup, taskloop, and teams. - return llvm::is_contained(privatizing, extractOmpDirective(omp)); + return llvm::is_contained(privatizing, + parser::omp::GetOmpDirectiveName(omp).v); } bool DataSharingProcessor::isOpenMPPrivatizingEvaluation( diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index fc5fef9b2c577..4c2d7badef382 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -31,6 +31,7 @@ #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Parser/characters.h" +#include "flang/Parser/openmp-utils.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/tools.h" @@ -63,28 +64,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, lower::pft::Evaluation &eval, mlir::Location loc); -static llvm::omp::Directive -getOpenMPDirectiveEnum(const parser::OmpLoopDirective &beginStatment) { - return beginStatment.v; -} - -static llvm::omp::Directive getOpenMPDirectiveEnum( - const parser::OmpBeginLoopDirective &beginLoopDirective) { - return getOpenMPDirectiveEnum( - std::get(beginLoopDirective.t)); -} - -static llvm::omp::Directive -getOpenMPDirectiveEnum(const parser::OpenMPLoopConstruct &ompLoopConstruct) { - return getOpenMPDirectiveEnum( - std::get(ompLoopConstruct.t)); -} - -static llvm::omp::Directive getOpenMPDirectiveEnum( - const common::Indirection &ompLoopConstruct) { - return getOpenMPDirectiveEnum(ompLoopConstruct.value()); -} - namespace { /// Structure holding information that is needed to pass host-evaluated /// information to later lowering stages. @@ -468,7 +447,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, llvm::omp::Directive dir; auto &nested = parent.getFirstNestedEvaluation(); if (const auto *ompEval = nested.getIf()) - dir = extractOmpDirective(*ompEval); + dir = parser::omp::GetOmpDirectiveName(*ompEval).v; else return std::nullopt; @@ -508,7 +487,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter); assert(hostInfo && "expected HOST_EVAL info structure"); - switch (extractOmpDirective(*ompEval)) { + switch (parser::omp::GetOmpDirectiveName(*ompEval).v) { case OMPD_teams_distribute_parallel_do: case OMPD_teams_distribute_parallel_do_simd: cp.processThreadLimit(stmtCtx, hostInfo->ops); @@ -569,7 +548,8 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, const auto *ompEval = eval.getIf(); assert(ompEval && - llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && + llvm::omp::allTargetSet.test( + parser::omp::GetOmpDirectiveName(*ompEval).v) && "expected TARGET construct evaluation"); (void)ompEval; @@ -3872,7 +3852,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, std::get_if>( &*optLoopCons)}) { llvm::omp::Directive nestedDirective = - getOpenMPDirectiveEnum(*ompNestedLoopCons); + parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v; switch (nestedDirective) { case llvm::omp::Directive::OMPD_tile: // Emit the omp.loop_nest with annotation for tiling @@ -3889,7 +3869,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, } } - llvm::omp::Directive directive = getOpenMPDirectiveEnum(beginLoopDirective); + llvm::omp::Directive directive = + parser::omp::GetOmpDirectiveName(beginLoopDirective).v; const parser::CharBlock &source = std::get(beginLoopDirective.t).source; ConstructQueue queue{ diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index b1716d6afb200..13fda978c5369 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -663,89 +664,6 @@ bool collectLoopRelatedInfo( return found; } -/// Get the directive enumeration value corresponding to the given OpenMP -/// construct PFT node. -llvm::omp::Directive -extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) { - return common::visit( - common::visitors{ - [](const parser::OpenMPAllocatorsConstruct &c) { - return llvm::omp::OMPD_allocators; - }, - [](const parser::OpenMPAssumeConstruct &c) { - return llvm::omp::OMPD_assume; - }, - [](const parser::OpenMPAtomicConstruct &c) { - return llvm::omp::OMPD_atomic; - }, - [](const parser::OpenMPBlockConstruct &c) { - return std::get( - std::get(c.t).t) - .v; - }, - [](const parser::OpenMPCriticalConstruct &c) { - return llvm::omp::OMPD_critical; - }, - [](const parser::OpenMPDeclarativeAllocate &c) { - return llvm::omp::OMPD_allocate; - }, - [](const parser::OpenMPDispatchConstruct &c) { - return llvm::omp::OMPD_dispatch; - }, - [](const parser::OpenMPExecutableAllocate &c) { - return llvm::omp::OMPD_allocate; - }, - [](const parser::OpenMPLoopConstruct &c) { - return std::get( - std::get(c.t).t) - .v; - }, - [](const parser::OpenMPSectionConstruct &c) { - return llvm::omp::OMPD_section; - }, - [](const parser::OpenMPSectionsConstruct &c) { - return std::get( - std::get(c.t).t) - .v; - }, - [](const parser::OpenMPStandaloneConstruct &c) { - return common::visit( - common::visitors{ - [](const parser::OpenMPSimpleStandaloneConstruct &c) { - return c.v.DirId(); - }, - [](const parser::OpenMPFlushConstruct &c) { - return llvm::omp::OMPD_flush; - }, - [](const parser::OpenMPCancelConstruct &c) { - return llvm::omp::OMPD_cancel; - }, - [](const parser::OpenMPCancellationPointConstruct &c) { - return llvm::omp::OMPD_cancellation_point; - }, - [](const parser::OmpMetadirectiveDirective &c) { - return llvm::omp::OMPD_metadirective; - }, - [](const parser::OpenMPDepobjConstruct &c) { - return llvm::omp::OMPD_depobj; - }, - [](const parser::OpenMPInteropConstruct &c) { - return llvm::omp::OMPD_interop; - }}, - c.u); - }, - [](const parser::OpenMPUtilityConstruct &c) { - return common::visit( - common::visitors{[](const parser::OmpErrorDirective &c) { - return llvm::omp::OMPD_error; - }, - [](const parser::OmpNothingDirective &c) { - return llvm::omp::OMPD_nothing; - }}, - c.u); - }}, - ompConstruct.u); -} } // namespace omp } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 8e3ad5c3452e2..11641ba5e8606 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -167,8 +167,6 @@ bool collectLoopRelatedInfo( mlir::omp::LoopRelatedClauseOps &result, llvm::SmallVectorImpl &iv); -llvm::omp::Directive -extractOmpDirective(const parser::OpenMPConstruct &ompConstruct); } // namespace omp } // namespace lower } // namespace Fortran