Skip to content

Commit 43db6c5

Browse files
authored
[flang][OpenMP] General utility to get directive id from AST node (#150121)
Fortran::parser::omp::GetOmpDirectiveName(t) will get the OmpDirectiveName object that corresponds to construct t. That object (an AST node) contains the enum id and the source information of the directive. Replace uses of extractOmpDirective and getOpenMPDirectiveEnum with the new function.
1 parent 081b74c commit 43db6c5

File tree

5 files changed

+173
-113
lines changed

5 files changed

+173
-113
lines changed
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
//===-- flang/Parser/openmp-utils.h ---------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Common OpenMP utilities.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef FORTRAN_PARSER_OPENMP_UTILS_H
14+
#define FORTRAN_PARSER_OPENMP_UTILS_H
15+
16+
#include "flang/Common/indirection.h"
17+
#include "flang/Parser/parse-tree.h"
18+
#include "llvm/Frontend/OpenMP/OMP.h"
19+
20+
#include <cassert>
21+
#include <tuple>
22+
#include <type_traits>
23+
#include <utility>
24+
#include <variant>
25+
26+
namespace Fortran::parser::omp {
27+
28+
namespace detail {
29+
using D = llvm::omp::Directive;
30+
31+
template <typename Construct> //
32+
struct ConstructId {
33+
static constexpr llvm::omp::Directive id{D::OMPD_unknown};
34+
};
35+
36+
#define MAKE_CONSTR_ID(Construct, Id) \
37+
template <> struct ConstructId<Construct> { \
38+
static constexpr llvm::omp::Directive id{Id}; \
39+
}
40+
41+
MAKE_CONSTR_ID(OmpAssumeDirective, D::OMPD_assume);
42+
MAKE_CONSTR_ID(OmpCriticalDirective, D::OMPD_critical);
43+
MAKE_CONSTR_ID(OmpDeclareVariantDirective, D::OMPD_declare_variant);
44+
MAKE_CONSTR_ID(OmpErrorDirective, D::OMPD_error);
45+
MAKE_CONSTR_ID(OmpMetadirectiveDirective, D::OMPD_metadirective);
46+
MAKE_CONSTR_ID(OpenMPDeclarativeAllocate, D::OMPD_allocate);
47+
MAKE_CONSTR_ID(OpenMPDeclarativeAssumes, D::OMPD_assumes);
48+
MAKE_CONSTR_ID(OpenMPDeclareMapperConstruct, D::OMPD_declare_mapper);
49+
MAKE_CONSTR_ID(OpenMPDeclareReductionConstruct, D::OMPD_declare_reduction);
50+
MAKE_CONSTR_ID(OpenMPDeclareSimdConstruct, D::OMPD_declare_simd);
51+
MAKE_CONSTR_ID(OpenMPDeclareTargetConstruct, D::OMPD_declare_target);
52+
MAKE_CONSTR_ID(OpenMPExecutableAllocate, D::OMPD_allocate);
53+
MAKE_CONSTR_ID(OpenMPRequiresConstruct, D::OMPD_requires);
54+
MAKE_CONSTR_ID(OpenMPThreadprivate, D::OMPD_threadprivate);
55+
56+
#undef MAKE_CONSTR_ID
57+
58+
struct DirectiveNameScope {
59+
static OmpDirectiveName MakeName(CharBlock source = {},
60+
llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) {
61+
OmpDirectiveName name;
62+
name.source = source;
63+
name.v = id;
64+
return name;
65+
}
66+
67+
static OmpDirectiveName GetOmpDirectiveName(const OmpNothingDirective &x) {
68+
return MakeName(x.source, llvm::omp::Directive::OMPD_nothing);
69+
}
70+
71+
static OmpDirectiveName GetOmpDirectiveName(const OmpBeginBlockDirective &x) {
72+
auto &dir{std::get<OmpBlockDirective>(x.t)};
73+
return MakeName(dir.source, dir.v);
74+
}
75+
76+
static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
77+
auto &dir{std::get<OmpLoopDirective>(x.t)};
78+
return MakeName(dir.source, dir.v);
79+
}
80+
81+
static OmpDirectiveName GetOmpDirectiveName(
82+
const OmpBeginSectionsDirective &x) {
83+
auto &dir{std::get<OmpSectionsDirective>(x.t)};
84+
return MakeName(dir.source, dir.v);
85+
}
86+
87+
template <typename T>
88+
static OmpDirectiveName GetOmpDirectiveName(const T &x) {
89+
if constexpr (WrapperTrait<T>) {
90+
if constexpr (std::is_same_v<T, OpenMPCancelConstruct> ||
91+
std::is_same_v<T, OpenMPCancellationPointConstruct> ||
92+
std::is_same_v<T, OpenMPDepobjConstruct> ||
93+
std::is_same_v<T, OpenMPFlushConstruct> ||
94+
std::is_same_v<T, OpenMPInteropConstruct> ||
95+
std::is_same_v<T, OpenMPSimpleStandaloneConstruct>) {
96+
return x.v.DirName();
97+
} else {
98+
return GetOmpDirectiveName(x.v);
99+
}
100+
} else if constexpr (TupleTrait<T>) {
101+
if constexpr (std::is_same_v<T, OpenMPAllocatorsConstruct> ||
102+
std::is_same_v<T, OpenMPAtomicConstruct> ||
103+
std::is_same_v<T, OpenMPDispatchConstruct>) {
104+
return std::get<OmpDirectiveSpecification>(x.t).DirName();
105+
} else if constexpr (std::is_same_v<T, OmpAssumeDirective> ||
106+
std::is_same_v<T, OmpCriticalDirective> ||
107+
std::is_same_v<T, OmpDeclareVariantDirective> ||
108+
std::is_same_v<T, OmpErrorDirective> ||
109+
std::is_same_v<T, OmpMetadirectiveDirective> ||
110+
std::is_same_v<T, OpenMPDeclarativeAllocate> ||
111+
std::is_same_v<T, OpenMPDeclarativeAssumes> ||
112+
std::is_same_v<T, OpenMPDeclareMapperConstruct> ||
113+
std::is_same_v<T, OpenMPDeclareReductionConstruct> ||
114+
std::is_same_v<T, OpenMPDeclareSimdConstruct> ||
115+
std::is_same_v<T, OpenMPDeclareTargetConstruct> ||
116+
std::is_same_v<T, OpenMPExecutableAllocate> ||
117+
std::is_same_v<T, OpenMPRequiresConstruct> ||
118+
std::is_same_v<T, OpenMPThreadprivate>) {
119+
return MakeName(std::get<Verbatim>(x.t).source, ConstructId<T>::id);
120+
} else {
121+
return GetFromTuple(
122+
x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{});
123+
}
124+
} else if constexpr (UnionTrait<T>) {
125+
return common::visit(
126+
[](auto &&s) { return GetOmpDirectiveName(s); }, x.u);
127+
} else {
128+
return MakeName();
129+
}
130+
}
131+
132+
template <typename... Ts, size_t... Is>
133+
static OmpDirectiveName GetFromTuple(
134+
const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
135+
OmpDirectiveName name = MakeName();
136+
auto accumulate = [&](const OmpDirectiveName &n) {
137+
if (name.v == llvm::omp::Directive::OMPD_unknown) {
138+
name = n;
139+
} else {
140+
assert(
141+
n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names");
142+
}
143+
};
144+
(accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...);
145+
return name;
146+
}
147+
148+
template <typename T>
149+
static OmpDirectiveName GetOmpDirectiveName(const common::Indirection<T> &x) {
150+
return GetOmpDirectiveName(x.value());
151+
}
152+
};
153+
} // namespace detail
154+
155+
template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
156+
return detail::DirectiveNameScope::GetOmpDirectiveName(x);
157+
}
158+
159+
} // namespace Fortran::parser::omp
160+
161+
#endif // FORTRAN_PARSER_OPENMP_UTILS_H

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "flang/Optimizer/Dialect/FIROps.h"
2525
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
2626
#include "flang/Optimizer/HLFIR/HLFIROps.h"
27+
#include "flang/Parser/openmp-utils.h"
2728
#include "flang/Semantics/attr.h"
2829
#include "flang/Semantics/tools.h"
2930
#include "llvm/ADT/Sequence.h"
@@ -465,7 +466,8 @@ bool DataSharingProcessor::isOpenMPPrivatizingConstruct(
465466
// allow a privatizing clause) are: dispatch, distribute, do, for, loop,
466467
// parallel, scope, sections, simd, single, target, target_data, task,
467468
// taskgroup, taskloop, and teams.
468-
return llvm::is_contained(privatizing, extractOmpDirective(omp));
469+
return llvm::is_contained(privatizing,
470+
parser::omp::GetOmpDirectiveName(omp).v);
469471
}
470472

471473
bool DataSharingProcessor::isOpenMPPrivatizingEvaluation(

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "flang/Optimizer/Dialect/FIRType.h"
3232
#include "flang/Optimizer/HLFIR/HLFIROps.h"
3333
#include "flang/Parser/characters.h"
34+
#include "flang/Parser/openmp-utils.h"
3435
#include "flang/Parser/parse-tree.h"
3536
#include "flang/Semantics/openmp-directive-sets.h"
3637
#include "flang/Semantics/tools.h"
@@ -63,28 +64,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
6364
lower::pft::Evaluation &eval,
6465
mlir::Location loc);
6566

66-
static llvm::omp::Directive
67-
getOpenMPDirectiveEnum(const parser::OmpLoopDirective &beginStatment) {
68-
return beginStatment.v;
69-
}
70-
71-
static llvm::omp::Directive getOpenMPDirectiveEnum(
72-
const parser::OmpBeginLoopDirective &beginLoopDirective) {
73-
return getOpenMPDirectiveEnum(
74-
std::get<parser::OmpLoopDirective>(beginLoopDirective.t));
75-
}
76-
77-
static llvm::omp::Directive
78-
getOpenMPDirectiveEnum(const parser::OpenMPLoopConstruct &ompLoopConstruct) {
79-
return getOpenMPDirectiveEnum(
80-
std::get<parser::OmpBeginLoopDirective>(ompLoopConstruct.t));
81-
}
82-
83-
static llvm::omp::Directive getOpenMPDirectiveEnum(
84-
const common::Indirection<parser::OpenMPLoopConstruct> &ompLoopConstruct) {
85-
return getOpenMPDirectiveEnum(ompLoopConstruct.value());
86-
}
87-
8867
namespace {
8968
/// Structure holding information that is needed to pass host-evaluated
9069
/// information to later lowering stages.
@@ -468,7 +447,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
468447
llvm::omp::Directive dir;
469448
auto &nested = parent.getFirstNestedEvaluation();
470449
if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
471-
dir = extractOmpDirective(*ompEval);
450+
dir = parser::omp::GetOmpDirectiveName(*ompEval).v;
472451
else
473452
return std::nullopt;
474453

@@ -508,7 +487,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
508487
HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
509488
assert(hostInfo && "expected HOST_EVAL info structure");
510489

511-
switch (extractOmpDirective(*ompEval)) {
490+
switch (parser::omp::GetOmpDirectiveName(*ompEval).v) {
512491
case OMPD_teams_distribute_parallel_do:
513492
case OMPD_teams_distribute_parallel_do_simd:
514493
cp.processThreadLimit(stmtCtx, hostInfo->ops);
@@ -569,7 +548,8 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
569548

570549
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
571550
assert(ompEval &&
572-
llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
551+
llvm::omp::allTargetSet.test(
552+
parser::omp::GetOmpDirectiveName(*ompEval).v) &&
573553
"expected TARGET construct evaluation");
574554
(void)ompEval;
575555

@@ -3872,7 +3852,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
38723852
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
38733853
&*optLoopCons)}) {
38743854
llvm::omp::Directive nestedDirective =
3875-
getOpenMPDirectiveEnum(*ompNestedLoopCons);
3855+
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
38763856
switch (nestedDirective) {
38773857
case llvm::omp::Directive::OMPD_tile:
38783858
// Emit the omp.loop_nest with annotation for tiling
@@ -3889,7 +3869,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
38893869
}
38903870
}
38913871

3892-
llvm::omp::Directive directive = getOpenMPDirectiveEnum(beginLoopDirective);
3872+
llvm::omp::Directive directive =
3873+
parser::omp::GetOmpDirectiveName(beginLoopDirective).v;
38933874
const parser::CharBlock &source =
38943875
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source;
38953876
ConstructQueue queue{

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 1 addition & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <flang/Lower/PFTBuilder.h>
2121
#include <flang/Optimizer/Builder/FIRBuilder.h>
2222
#include <flang/Optimizer/Builder/Todo.h>
23+
#include <flang/Parser/openmp-utils.h>
2324
#include <flang/Parser/parse-tree.h>
2425
#include <flang/Parser/tools.h>
2526
#include <flang/Semantics/tools.h>
@@ -663,89 +664,6 @@ bool collectLoopRelatedInfo(
663664
return found;
664665
}
665666

666-
/// Get the directive enumeration value corresponding to the given OpenMP
667-
/// construct PFT node.
668-
llvm::omp::Directive
669-
extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
670-
return common::visit(
671-
common::visitors{
672-
[](const parser::OpenMPAllocatorsConstruct &c) {
673-
return llvm::omp::OMPD_allocators;
674-
},
675-
[](const parser::OpenMPAssumeConstruct &c) {
676-
return llvm::omp::OMPD_assume;
677-
},
678-
[](const parser::OpenMPAtomicConstruct &c) {
679-
return llvm::omp::OMPD_atomic;
680-
},
681-
[](const parser::OpenMPBlockConstruct &c) {
682-
return std::get<parser::OmpBlockDirective>(
683-
std::get<parser::OmpBeginBlockDirective>(c.t).t)
684-
.v;
685-
},
686-
[](const parser::OpenMPCriticalConstruct &c) {
687-
return llvm::omp::OMPD_critical;
688-
},
689-
[](const parser::OpenMPDeclarativeAllocate &c) {
690-
return llvm::omp::OMPD_allocate;
691-
},
692-
[](const parser::OpenMPDispatchConstruct &c) {
693-
return llvm::omp::OMPD_dispatch;
694-
},
695-
[](const parser::OpenMPExecutableAllocate &c) {
696-
return llvm::omp::OMPD_allocate;
697-
},
698-
[](const parser::OpenMPLoopConstruct &c) {
699-
return std::get<parser::OmpLoopDirective>(
700-
std::get<parser::OmpBeginLoopDirective>(c.t).t)
701-
.v;
702-
},
703-
[](const parser::OpenMPSectionConstruct &c) {
704-
return llvm::omp::OMPD_section;
705-
},
706-
[](const parser::OpenMPSectionsConstruct &c) {
707-
return std::get<parser::OmpSectionsDirective>(
708-
std::get<parser::OmpBeginSectionsDirective>(c.t).t)
709-
.v;
710-
},
711-
[](const parser::OpenMPStandaloneConstruct &c) {
712-
return common::visit(
713-
common::visitors{
714-
[](const parser::OpenMPSimpleStandaloneConstruct &c) {
715-
return c.v.DirId();
716-
},
717-
[](const parser::OpenMPFlushConstruct &c) {
718-
return llvm::omp::OMPD_flush;
719-
},
720-
[](const parser::OpenMPCancelConstruct &c) {
721-
return llvm::omp::OMPD_cancel;
722-
},
723-
[](const parser::OpenMPCancellationPointConstruct &c) {
724-
return llvm::omp::OMPD_cancellation_point;
725-
},
726-
[](const parser::OmpMetadirectiveDirective &c) {
727-
return llvm::omp::OMPD_metadirective;
728-
},
729-
[](const parser::OpenMPDepobjConstruct &c) {
730-
return llvm::omp::OMPD_depobj;
731-
},
732-
[](const parser::OpenMPInteropConstruct &c) {
733-
return llvm::omp::OMPD_interop;
734-
}},
735-
c.u);
736-
},
737-
[](const parser::OpenMPUtilityConstruct &c) {
738-
return common::visit(
739-
common::visitors{[](const parser::OmpErrorDirective &c) {
740-
return llvm::omp::OMPD_error;
741-
},
742-
[](const parser::OmpNothingDirective &c) {
743-
return llvm::omp::OMPD_nothing;
744-
}},
745-
c.u);
746-
}},
747-
ompConstruct.u);
748-
}
749667
} // namespace omp
750668
} // namespace lower
751669
} // namespace Fortran

flang/lib/Lower/OpenMP/Utils.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,6 @@ bool collectLoopRelatedInfo(
167167
mlir::omp::LoopRelatedClauseOps &result,
168168
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
169169

170-
llvm::omp::Directive
171-
extractOmpDirective(const parser::OpenMPConstruct &ompConstruct);
172170
} // namespace omp
173171
} // namespace lower
174172
} // namespace Fortran

0 commit comments

Comments
 (0)