Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions flang/include/flang/Parser/openmp-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
//===-- flang/Parser/openmp-utils.h ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Common OpenMP utilities.
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_PARSER_OPENMP_UTILS_H
#define FORTRAN_PARSER_OPENMP_UTILS_H

#include "flang/Common/indirection.h"
#include "flang/Parser/parse-tree.h"
#include "llvm/Frontend/OpenMP/OMP.h"

#include <cassert>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>

namespace Fortran::parser::omp {

namespace detail {
using D = llvm::omp::Directive;

template <typename Construct> //
struct ConstructId {
static constexpr llvm::omp::Directive id{D::OMPD_unknown};
};

#define MAKE_CONSTR_ID(Construct, Id) \
template <> struct ConstructId<Construct> { \
static constexpr llvm::omp::Directive id{Id}; \
}

MAKE_CONSTR_ID(OmpAssumeDirective, D::OMPD_assume);
MAKE_CONSTR_ID(OmpCriticalDirective, D::OMPD_critical);
MAKE_CONSTR_ID(OmpDeclareVariantDirective, D::OMPD_declare_variant);
MAKE_CONSTR_ID(OmpErrorDirective, D::OMPD_error);
MAKE_CONSTR_ID(OmpMetadirectiveDirective, D::OMPD_metadirective);
MAKE_CONSTR_ID(OpenMPDeclarativeAllocate, D::OMPD_allocate);
MAKE_CONSTR_ID(OpenMPDeclarativeAssumes, D::OMPD_assumes);
MAKE_CONSTR_ID(OpenMPDeclareMapperConstruct, D::OMPD_declare_mapper);
MAKE_CONSTR_ID(OpenMPDeclareReductionConstruct, D::OMPD_declare_reduction);
MAKE_CONSTR_ID(OpenMPDeclareSimdConstruct, D::OMPD_declare_simd);
MAKE_CONSTR_ID(OpenMPDeclareTargetConstruct, D::OMPD_declare_target);
MAKE_CONSTR_ID(OpenMPExecutableAllocate, D::OMPD_allocate);
MAKE_CONSTR_ID(OpenMPRequiresConstruct, D::OMPD_requires);
MAKE_CONSTR_ID(OpenMPThreadprivate, D::OMPD_threadprivate);

#undef MAKE_CONSTR_ID

struct DirectiveNameScope {
// Helper types to make overloaded function signatures different.
struct TagA {};
struct TagB {};
struct TagC {};
struct TagD {};

static OmpDirectiveName MakeName(CharBlock source = {},
llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) {
OmpDirectiveName name;
name.source = source;
name.v = id;
return name;
}

static OmpDirectiveName GetOmpDirectiveName(const OmpNothingDirective &x) {
return MakeName(x.source, llvm::omp::Directive::OMPD_nothing);
}

static OmpDirectiveName GetOmpDirectiveName(const OmpBeginBlockDirective &x) {
auto &dir{std::get<OmpBlockDirective>(x.t)};
return MakeName(dir.source, dir.v);
}

static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
auto &dir{std::get<OmpLoopDirective>(x.t)};
return MakeName(dir.source, dir.v);
}

static OmpDirectiveName GetOmpDirectiveName(
const OmpBeginSectionsDirective &x) {
auto &dir{std::get<OmpSectionsDirective>(x.t)};
return MakeName(dir.source, dir.v);
}

template <typename T, typename = std::enable_if_t<WrapperTrait<T>>>
static OmpDirectiveName GetOmpDirectiveName(const T &x, TagA = {}) {
if constexpr (std::is_same_v<T, OpenMPCancelConstruct> ||
std::is_same_v<T, OpenMPCancellationPointConstruct> ||
std::is_same_v<T, OpenMPDepobjConstruct> ||
std::is_same_v<T, OpenMPFlushConstruct> ||
std::is_same_v<T, OpenMPInteropConstruct> ||
std::is_same_v<T, OpenMPSimpleStandaloneConstruct>) {
return x.v.DirName();
} else {
return GetOmpDirectiveName(x.v);
}
}

template <typename T, typename = std::enable_if_t<TupleTrait<T>>>
static OmpDirectiveName GetOmpDirectiveName(const T &x, TagB = {}) {
if constexpr (std::is_same_v<T, OpenMPAllocatorsConstruct> ||
std::is_same_v<T, OpenMPAtomicConstruct> ||
std::is_same_v<T, OpenMPDispatchConstruct>) {
return std::get<OmpDirectiveSpecification>(x.t).DirName();
} else if constexpr (std::is_same_v<T, OmpAssumeDirective> ||
std::is_same_v<T, OmpCriticalDirective> ||
std::is_same_v<T, OmpDeclareVariantDirective> ||
std::is_same_v<T, OmpErrorDirective> ||
std::is_same_v<T, OmpMetadirectiveDirective> ||
std::is_same_v<T, OpenMPDeclarativeAllocate> ||
std::is_same_v<T, OpenMPDeclarativeAssumes> ||
std::is_same_v<T, OpenMPDeclareMapperConstruct> ||
std::is_same_v<T, OpenMPDeclareReductionConstruct> ||
std::is_same_v<T, OpenMPDeclareSimdConstruct> ||
std::is_same_v<T, OpenMPDeclareTargetConstruct> ||
std::is_same_v<T, OpenMPExecutableAllocate> ||
std::is_same_v<T, OpenMPRequiresConstruct> ||
std::is_same_v<T, OpenMPThreadprivate>) {
return MakeName(std::get<Verbatim>(x.t).source, ConstructId<T>::id);
} else {
return GetFromTuple(
x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{});
}
}

template <typename T, typename = std::enable_if_t<UnionTrait<T>>>
static OmpDirectiveName GetOmpDirectiveName(const T &x, TagC = {}) {
return common::visit([](auto &&s) { return GetOmpDirectiveName(s); }, x.u);
}

template <typename... Ts, size_t... Is>
static OmpDirectiveName GetFromTuple(
const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
OmpDirectiveName name = MakeName();
auto accumulate = [&](const OmpDirectiveName &n) {
if (name.v == llvm::omp::Directive::OMPD_unknown) {
name = n;
} else {
assert(
n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names");
}
};
(accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...);
return name;
}

template <typename T>
static OmpDirectiveName GetOmpDirectiveName(const common::Indirection<T> &x) {
return GetOmpDirectiveName(x.value());
}

template <typename T,
typename = std::enable_if_t<!WrapperTrait<T> && !TupleTrait<T> &&
!UnionTrait<T>>>
static OmpDirectiveName GetOmpDirectiveName(const T &x, TagD = {}) {
return MakeName();
}
};
} // namespace detail

template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
return detail::DirectiveNameScope::GetOmpDirectiveName(x);
}

} // namespace Fortran::parser::omp

#endif // FORTRAN_PARSER_OPENMP_UTILS_H
4 changes: 3 additions & 1 deletion flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Parser/openmp-utils.h"
#include "flang/Semantics/attr.h"
#include "flang/Semantics/tools.h"
#include "llvm/ADT/Sequence.h"
Expand Down Expand Up @@ -465,7 +466,8 @@ bool DataSharingProcessor::isOpenMPPrivatizingConstruct(
// allow a privatizing clause) are: dispatch, distribute, do, for, loop,
// parallel, scope, sections, simd, single, target, target_data, task,
// taskgroup, taskloop, and teams.
return llvm::is_contained(privatizing, extractOmpDirective(omp));
return llvm::is_contained(privatizing,
parser::omp::GetOmpDirectiveName(omp).v);
}

bool DataSharingProcessor::isOpenMPPrivatizingEvaluation(
Expand Down
35 changes: 8 additions & 27 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Parser/characters.h"
#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/tools.h"
Expand Down Expand Up @@ -63,28 +64,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
lower::pft::Evaluation &eval,
mlir::Location loc);

static llvm::omp::Directive
getOpenMPDirectiveEnum(const parser::OmpLoopDirective &beginStatment) {
return beginStatment.v;
}

static llvm::omp::Directive getOpenMPDirectiveEnum(
const parser::OmpBeginLoopDirective &beginLoopDirective) {
return getOpenMPDirectiveEnum(
std::get<parser::OmpLoopDirective>(beginLoopDirective.t));
}

static llvm::omp::Directive
getOpenMPDirectiveEnum(const parser::OpenMPLoopConstruct &ompLoopConstruct) {
return getOpenMPDirectiveEnum(
std::get<parser::OmpBeginLoopDirective>(ompLoopConstruct.t));
}

static llvm::omp::Directive getOpenMPDirectiveEnum(
const common::Indirection<parser::OpenMPLoopConstruct> &ompLoopConstruct) {
return getOpenMPDirectiveEnum(ompLoopConstruct.value());
}

namespace {
/// Structure holding information that is needed to pass host-evaluated
/// information to later lowering stages.
Expand Down Expand Up @@ -468,7 +447,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
llvm::omp::Directive dir;
auto &nested = parent.getFirstNestedEvaluation();
if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
dir = extractOmpDirective(*ompEval);
dir = parser::omp::GetOmpDirectiveName(*ompEval).v;
else
return std::nullopt;

Expand Down Expand Up @@ -508,7 +487,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
assert(hostInfo && "expected HOST_EVAL info structure");

switch (extractOmpDirective(*ompEval)) {
switch (parser::omp::GetOmpDirectiveName(*ompEval).v) {
case OMPD_teams_distribute_parallel_do:
case OMPD_teams_distribute_parallel_do_simd:
cp.processThreadLimit(stmtCtx, hostInfo->ops);
Expand Down Expand Up @@ -569,7 +548,8 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,

const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
assert(ompEval &&
llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
llvm::omp::allTargetSet.test(
parser::omp::GetOmpDirectiveName(*ompEval).v) &&
"expected TARGET construct evaluation");
(void)ompEval;

Expand Down Expand Up @@ -3872,7 +3852,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&*optLoopCons)}) {
llvm::omp::Directive nestedDirective =
getOpenMPDirectiveEnum(*ompNestedLoopCons);
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
case llvm::omp::Directive::OMPD_tile:
// Emit the omp.loop_nest with annotation for tiling
Expand All @@ -3889,7 +3869,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
}
}

llvm::omp::Directive directive = getOpenMPDirectiveEnum(beginLoopDirective);
llvm::omp::Directive directive =
parser::omp::GetOmpDirectiveName(beginLoopDirective).v;
const parser::CharBlock &source =
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source;
ConstructQueue queue{
Expand Down
84 changes: 1 addition & 83 deletions flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <flang/Lower/PFTBuilder.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
#include <flang/Parser/openmp-utils.h>
#include <flang/Parser/parse-tree.h>
#include <flang/Parser/tools.h>
#include <flang/Semantics/tools.h>
Expand Down Expand Up @@ -663,89 +664,6 @@ bool collectLoopRelatedInfo(
return found;
}

/// Get the directive enumeration value corresponding to the given OpenMP
/// construct PFT node.
llvm::omp::Directive
extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
return common::visit(
common::visitors{
[](const parser::OpenMPAllocatorsConstruct &c) {
return llvm::omp::OMPD_allocators;
},
[](const parser::OpenMPAssumeConstruct &c) {
return llvm::omp::OMPD_assume;
},
[](const parser::OpenMPAtomicConstruct &c) {
return llvm::omp::OMPD_atomic;
},
[](const parser::OpenMPBlockConstruct &c) {
return std::get<parser::OmpBlockDirective>(
std::get<parser::OmpBeginBlockDirective>(c.t).t)
.v;
},
[](const parser::OpenMPCriticalConstruct &c) {
return llvm::omp::OMPD_critical;
},
[](const parser::OpenMPDeclarativeAllocate &c) {
return llvm::omp::OMPD_allocate;
},
[](const parser::OpenMPDispatchConstruct &c) {
return llvm::omp::OMPD_dispatch;
},
[](const parser::OpenMPExecutableAllocate &c) {
return llvm::omp::OMPD_allocate;
},
[](const parser::OpenMPLoopConstruct &c) {
return std::get<parser::OmpLoopDirective>(
std::get<parser::OmpBeginLoopDirective>(c.t).t)
.v;
},
[](const parser::OpenMPSectionConstruct &c) {
return llvm::omp::OMPD_section;
},
[](const parser::OpenMPSectionsConstruct &c) {
return std::get<parser::OmpSectionsDirective>(
std::get<parser::OmpBeginSectionsDirective>(c.t).t)
.v;
},
[](const parser::OpenMPStandaloneConstruct &c) {
return common::visit(
common::visitors{
[](const parser::OpenMPSimpleStandaloneConstruct &c) {
return c.v.DirId();
},
[](const parser::OpenMPFlushConstruct &c) {
return llvm::omp::OMPD_flush;
},
[](const parser::OpenMPCancelConstruct &c) {
return llvm::omp::OMPD_cancel;
},
[](const parser::OpenMPCancellationPointConstruct &c) {
return llvm::omp::OMPD_cancellation_point;
},
[](const parser::OmpMetadirectiveDirective &c) {
return llvm::omp::OMPD_metadirective;
},
[](const parser::OpenMPDepobjConstruct &c) {
return llvm::omp::OMPD_depobj;
},
[](const parser::OpenMPInteropConstruct &c) {
return llvm::omp::OMPD_interop;
}},
c.u);
},
[](const parser::OpenMPUtilityConstruct &c) {
return common::visit(
common::visitors{[](const parser::OmpErrorDirective &c) {
return llvm::omp::OMPD_error;
},
[](const parser::OmpNothingDirective &c) {
return llvm::omp::OMPD_nothing;
}},
c.u);
}},
ompConstruct.u);
}
} // namespace omp
} // namespace lower
} // namespace Fortran
2 changes: 0 additions & 2 deletions flang/lib/Lower/OpenMP/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,6 @@ bool collectLoopRelatedInfo(
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);

llvm::omp::Directive
extractOmpDirective(const parser::OpenMPConstruct &ompConstruct);
} // namespace omp
} // namespace lower
} // namespace Fortran
Expand Down
Loading