Skip to content

Commit e75e28a

Browse files
authored
[flang][OpenMP] Use OmpDirectiveSpecification in Omp[Begin|End]LoopDi… (#159087)
…rective This makes accessing directive components, such as directive name or the list of clauses simpler and more uniform across different directives. It also makes the parser simpler, since it reuses existing parsing functionality. The changes are scattered over a number of files, but they all share the same nature: - getting the begin/end directive from OpenMPLoopConstruct, - getting the llvm::omp::Directive enum, and the source location, - getting the clause list.
1 parent b22448c commit e75e28a

37 files changed

+349
-478
lines changed

flang/include/flang/Parser/openmp-utils.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ struct DirectiveNameScope {
6767
}
6868

6969
static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
70-
auto &dir{std::get<OmpLoopDirective>(x.t)};
71-
return MakeName(dir.source, dir.v);
70+
return x.DirName();
7271
}
7372

7473
static OmpDirectiveName GetOmpDirectiveName(const OpenMPSectionConstruct &x) {

flang/include/flang/Parser/parse-tree.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5158,16 +5158,12 @@ struct OpenMPStandaloneConstruct {
51585158
u;
51595159
};
51605160

5161-
struct OmpBeginLoopDirective {
5162-
TUPLE_CLASS_BOILERPLATE(OmpBeginLoopDirective);
5163-
std::tuple<OmpLoopDirective, OmpClauseList> t;
5164-
CharBlock source;
5161+
struct OmpBeginLoopDirective : public OmpBeginDirective {
5162+
INHERITED_TUPLE_CLASS_BOILERPLATE(OmpBeginLoopDirective, OmpBeginDirective);
51655163
};
51665164

5167-
struct OmpEndLoopDirective {
5168-
TUPLE_CLASS_BOILERPLATE(OmpEndLoopDirective);
5169-
std::tuple<OmpLoopDirective, OmpClauseList> t;
5170-
CharBlock source;
5165+
struct OmpEndLoopDirective : public OmpEndDirective {
5166+
INHERITED_TUPLE_CLASS_BOILERPLATE(OmpEndLoopDirective, OmpEndDirective);
51715167
};
51725168

51735169
// OpenMP directives enclosing do loop
@@ -5177,6 +5173,13 @@ struct OpenMPLoopConstruct {
51775173
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
51785174
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
51795175
: t({std::move(a), std::nullopt, std::nullopt}) {}
5176+
5177+
const OmpBeginLoopDirective &BeginDir() const {
5178+
return std::get<OmpBeginLoopDirective>(t);
5179+
}
5180+
const std::optional<OmpEndLoopDirective> &EndDir() const {
5181+
return std::get<std::optional<OmpEndLoopDirective>>(t);
5182+
}
51805183
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
51815184
std::optional<OmpEndLoopDirective>>
51825185
t;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -408,26 +408,15 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
408408
const parser::OmpClauseList *beginClauseList = nullptr;
409409
const parser::OmpClauseList *endClauseList = nullptr;
410410
common::visit(
411-
common::visitors{
412-
[&](const parser::OmpBlockConstruct &ompConstruct) {
413-
beginClauseList = &ompConstruct.BeginDir().Clauses();
414-
if (auto &endSpec = ompConstruct.EndDir())
415-
endClauseList = &endSpec->Clauses();
416-
},
417-
[&](const parser::OpenMPLoopConstruct &ompConstruct) {
418-
const auto &beginDirective =
419-
std::get<parser::OmpBeginLoopDirective>(ompConstruct.t);
420-
beginClauseList =
421-
&std::get<parser::OmpClauseList>(beginDirective.t);
422-
423-
if (auto &endDirective =
424-
std::get<std::optional<parser::OmpEndLoopDirective>>(
425-
ompConstruct.t)) {
426-
endClauseList =
427-
&std::get<parser::OmpClauseList>(endDirective->t);
428-
}
429-
},
430-
[&](const auto &) {}},
411+
[&](const auto &construct) {
412+
using Type = llvm::remove_cvref_t<decltype(construct)>;
413+
if constexpr (std::is_same_v<Type, parser::OmpBlockConstruct> ||
414+
std::is_same_v<Type, parser::OpenMPLoopConstruct>) {
415+
beginClauseList = &construct.BeginDir().Clauses();
416+
if (auto &endSpec = construct.EndDir())
417+
endClauseList = &endSpec->Clauses();
418+
}
419+
},
431420
ompEval->u);
432421

433422
assert(beginClauseList && "expected begin directive");
@@ -3820,19 +3809,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
38203809
semantics::SemanticsContext &semaCtx,
38213810
lower::pft::Evaluation &eval,
38223811
const parser::OpenMPLoopConstruct &loopConstruct) {
3823-
const auto &beginLoopDirective =
3824-
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t);
3825-
List<Clause> clauses = makeClauses(
3826-
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
3827-
if (auto &endLoopDirective =
3828-
std::get<std::optional<parser::OmpEndLoopDirective>>(
3829-
loopConstruct.t)) {
3830-
clauses.append(makeClauses(
3831-
std::get<parser::OmpClauseList>(endLoopDirective->t), semaCtx));
3832-
}
3812+
const parser::OmpDirectiveSpecification &beginSpec = loopConstruct.BeginDir();
3813+
List<Clause> clauses = makeClauses(beginSpec.Clauses(), semaCtx);
3814+
if (auto &endSpec = loopConstruct.EndDir())
3815+
clauses.append(makeClauses(endSpec->Clauses(), semaCtx));
38333816

3834-
mlir::Location currentLocation =
3835-
converter.genLocation(beginLoopDirective.source);
3817+
mlir::Location currentLocation = converter.genLocation(beginSpec.source);
38363818

38373819
auto &optLoopCons =
38383820
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
@@ -3858,13 +3840,10 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
38583840
}
38593841
}
38603842

3861-
llvm::omp::Directive directive =
3862-
parser::omp::GetOmpDirectiveName(beginLoopDirective).v;
3863-
const parser::CharBlock &source =
3864-
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source;
3843+
const parser::OmpDirectiveName &beginName = beginSpec.DirName();
38653844
ConstructQueue queue{
38663845
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
3867-
eval, source, directive, clauses)};
3846+
eval, beginName.source, beginName.v, clauses)};
38683847
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
38693848
queue.begin());
38703849
}
@@ -4047,8 +4026,7 @@ bool Fortran::lower::isOpenMPTargetConstruct(
40474026
dir = block->BeginDir().DirId();
40484027
} else if (const auto *loop =
40494028
std::get_if<parser::OpenMPLoopConstruct>(&omp.u)) {
4050-
const auto &begin = std::get<parser::OmpBeginLoopDirective>(loop->t);
4051-
dir = std::get<parser::OmpLoopDirective>(begin.t).v;
4029+
dir = loop->BeginDir().DirId();
40524030
}
40534031
return llvm::omp::allTargetSet.test(dir);
40544032
}

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -616,16 +616,11 @@ static void processTileSizesFromOpenMPConstruct(
616616
&(nestedOptional.value()));
617617
if (innerConstruct) {
618618
const auto &innerLoopDirective = innerConstruct->value();
619-
const auto &innerBegin =
620-
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
621-
const auto &innerDirective =
622-
std::get<parser::OmpLoopDirective>(innerBegin.t).v;
623-
624-
if (innerDirective == llvm::omp::Directive::OMPD_tile) {
619+
const parser::OmpDirectiveSpecification &innerBeginSpec =
620+
innerLoopDirective.BeginDir();
621+
if (innerBeginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
625622
// Get the size values from parse tree and convert to a vector.
626-
const auto &innerClauseList{
627-
std::get<parser::OmpClauseList>(innerBegin.t)};
628-
for (const auto &clause : innerClauseList.v) {
623+
for (const auto &clause : innerBeginSpec.Clauses().v) {
629624
if (const auto tclause{
630625
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
631626
processFun(tclause);

flang/lib/Parser/openmp-parsers.cpp

Lines changed: 84 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,32 @@
1818
#include "flang/Parser/openmp-utils.h"
1919
#include "flang/Parser/parse-tree.h"
2020
#include "llvm/ADT/ArrayRef.h"
21+
#include "llvm/ADT/Bitset.h"
2122
#include "llvm/ADT/STLExtras.h"
2223
#include "llvm/ADT/StringRef.h"
2324
#include "llvm/ADT/StringSet.h"
2425
#include "llvm/Frontend/OpenMP/OMP.h"
26+
#include "llvm/Support/MathExtras.h"
27+
28+
#include <algorithm>
29+
#include <cctype>
30+
#include <iterator>
31+
#include <list>
32+
#include <optional>
33+
#include <string>
34+
#include <tuple>
35+
#include <type_traits>
36+
#include <utility>
37+
#include <variant>
38+
#include <vector>
2539

2640
// OpenMP Directives and Clauses
2741
namespace Fortran::parser {
2842
using namespace Fortran::parser::omp;
2943

44+
using DirectiveSet =
45+
llvm::Bitset<llvm::NextPowerOf2(llvm::omp::Directive_enumSize)>;
46+
3047
// Helper function to print the buffer contents starting at the current point.
3148
[[maybe_unused]] static std::string ahead(const ParseState &state) {
3249
return std::string(
@@ -1349,95 +1366,46 @@ TYPE_PARSER(sourced(construct<OpenMPUtilityConstruct>(
13491366
TYPE_PARSER(sourced(construct<OmpMetadirectiveDirective>(
13501367
verbatim("METADIRECTIVE"_tok), Parser<OmpClauseList>{})))
13511368

1352-
// Omp directives enclosing do loop
1353-
TYPE_PARSER(sourced(construct<OmpLoopDirective>(first(
1354-
"DISTRIBUTE PARALLEL DO SIMD" >>
1355-
pure(llvm::omp::Directive::OMPD_distribute_parallel_do_simd),
1356-
"DISTRIBUTE PARALLEL DO" >>
1357-
pure(llvm::omp::Directive::OMPD_distribute_parallel_do),
1358-
"DISTRIBUTE SIMD" >> pure(llvm::omp::Directive::OMPD_distribute_simd),
1359-
"DISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_distribute),
1360-
"DO SIMD" >> pure(llvm::omp::Directive::OMPD_do_simd),
1361-
"DO" >> pure(llvm::omp::Directive::OMPD_do),
1362-
"LOOP" >> pure(llvm::omp::Directive::OMPD_loop),
1363-
"MASKED TASKLOOP SIMD" >>
1364-
pure(llvm::omp::Directive::OMPD_masked_taskloop_simd),
1365-
"MASKED TASKLOOP" >> pure(llvm::omp::Directive::OMPD_masked_taskloop),
1366-
"MASTER TASKLOOP SIMD" >>
1367-
pure(llvm::omp::Directive::OMPD_master_taskloop_simd),
1368-
"MASTER TASKLOOP" >> pure(llvm::omp::Directive::OMPD_master_taskloop),
1369-
"PARALLEL DO SIMD" >> pure(llvm::omp::Directive::OMPD_parallel_do_simd),
1370-
"PARALLEL DO" >> pure(llvm::omp::Directive::OMPD_parallel_do),
1371-
"PARALLEL MASKED TASKLOOP SIMD" >>
1372-
pure(llvm::omp::Directive::OMPD_parallel_masked_taskloop_simd),
1373-
"PARALLEL MASKED TASKLOOP" >>
1374-
pure(llvm::omp::Directive::OMPD_parallel_masked_taskloop),
1375-
"PARALLEL MASTER TASKLOOP SIMD" >>
1376-
pure(llvm::omp::Directive::OMPD_parallel_master_taskloop_simd),
1377-
"PARALLEL MASTER TASKLOOP" >>
1378-
pure(llvm::omp::Directive::OMPD_parallel_master_taskloop),
1379-
"SIMD" >> pure(llvm::omp::Directive::OMPD_simd),
1380-
"TARGET LOOP" >> pure(llvm::omp::Directive::OMPD_target_loop),
1381-
"TARGET PARALLEL DO SIMD" >>
1382-
pure(llvm::omp::Directive::OMPD_target_parallel_do_simd),
1383-
"TARGET PARALLEL DO" >> pure(llvm::omp::Directive::OMPD_target_parallel_do),
1384-
"TARGET PARALLEL LOOP" >>
1385-
pure(llvm::omp::Directive::OMPD_target_parallel_loop),
1386-
"TARGET SIMD" >> pure(llvm::omp::Directive::OMPD_target_simd),
1387-
"TARGET TEAMS DISTRIBUTE PARALLEL DO SIMD" >>
1388-
pure(llvm::omp::Directive::
1389-
OMPD_target_teams_distribute_parallel_do_simd),
1390-
"TARGET TEAMS DISTRIBUTE PARALLEL DO" >>
1391-
pure(llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do),
1392-
"TARGET TEAMS DISTRIBUTE SIMD" >>
1393-
pure(llvm::omp::Directive::OMPD_target_teams_distribute_simd),
1394-
"TARGET TEAMS DISTRIBUTE" >>
1395-
pure(llvm::omp::Directive::OMPD_target_teams_distribute),
1396-
"TARGET TEAMS LOOP" >> pure(llvm::omp::Directive::OMPD_target_teams_loop),
1397-
"TASKLOOP SIMD" >> pure(llvm::omp::Directive::OMPD_taskloop_simd),
1398-
"TASKLOOP" >> pure(llvm::omp::Directive::OMPD_taskloop),
1399-
"TEAMS DISTRIBUTE PARALLEL DO SIMD" >>
1400-
pure(llvm::omp::Directive::OMPD_teams_distribute_parallel_do_simd),
1401-
"TEAMS DISTRIBUTE PARALLEL DO" >>
1402-
pure(llvm::omp::Directive::OMPD_teams_distribute_parallel_do),
1403-
"TEAMS DISTRIBUTE SIMD" >>
1404-
pure(llvm::omp::Directive::OMPD_teams_distribute_simd),
1405-
"TEAMS DISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_teams_distribute),
1406-
"TEAMS LOOP" >> pure(llvm::omp::Directive::OMPD_teams_loop),
1407-
"TILE" >> pure(llvm::omp::Directive::OMPD_tile),
1408-
"UNROLL" >> pure(llvm::omp::Directive::OMPD_unroll)))))
1409-
1410-
TYPE_PARSER(sourced(construct<OmpBeginLoopDirective>(
1411-
sourced(Parser<OmpLoopDirective>{}), Parser<OmpClauseList>{})))
1412-
14131369
static inline constexpr auto IsDirective(llvm::omp::Directive dir) {
14141370
return [dir](const OmpDirectiveName &name) -> bool { return dir == name.v; };
14151371
}
14161372

1373+
static inline constexpr auto IsMemberOf(const DirectiveSet &dirs) {
1374+
return [&dirs](const OmpDirectiveName &name) -> bool {
1375+
return dirs.test(llvm::to_underlying(name.v));
1376+
};
1377+
}
1378+
14171379
struct OmpBeginDirectiveParser {
14181380
using resultType = OmpDirectiveSpecification;
14191381

1420-
constexpr OmpBeginDirectiveParser(llvm::omp::Directive dir) : dir_(dir) {}
1382+
constexpr OmpBeginDirectiveParser(DirectiveSet dirs) : dirs_(dirs) {}
1383+
constexpr OmpBeginDirectiveParser(llvm::omp::Directive dir) {
1384+
dirs_.set(llvm::to_underlying(dir));
1385+
}
14211386

14221387
std::optional<resultType> Parse(ParseState &state) const {
1423-
auto &&p{predicated(Parser<OmpDirectiveName>{}, IsDirective(dir_)) >=
1388+
auto &&p{predicated(Parser<OmpDirectiveName>{}, IsMemberOf(dirs_)) >=
14241389
Parser<OmpDirectiveSpecification>{}};
14251390
return p.Parse(state);
14261391
}
14271392

14281393
private:
1429-
llvm::omp::Directive dir_;
1394+
DirectiveSet dirs_;
14301395
};
14311396

14321397
struct OmpEndDirectiveParser {
14331398
using resultType = OmpDirectiveSpecification;
14341399

1435-
constexpr OmpEndDirectiveParser(llvm::omp::Directive dir) : dir_(dir) {}
1400+
constexpr OmpEndDirectiveParser(DirectiveSet dirs) : dirs_(dirs) {}
1401+
constexpr OmpEndDirectiveParser(llvm::omp::Directive dir) {
1402+
dirs_.set(llvm::to_underlying(dir));
1403+
}
14361404

14371405
std::optional<resultType> Parse(ParseState &state) const {
14381406
if (startOmpLine.Parse(state)) {
14391407
if (auto endToken{verbatim("END"_sptok).Parse(state)}) {
1440-
if (auto &&dirSpec{OmpBeginDirectiveParser(dir_).Parse(state)}) {
1408+
if (auto &&dirSpec{OmpBeginDirectiveParser(dirs_).Parse(state)}) {
14411409
// Extend the "source" on both the OmpDirectiveName and the
14421410
// OmpDirectiveNameSpecification.
14431411
CharBlock &nameSource{std::get<OmpDirectiveName>(dirSpec->t).source};
@@ -1451,7 +1419,7 @@ struct OmpEndDirectiveParser {
14511419
}
14521420

14531421
private:
1454-
llvm::omp::Directive dir_;
1422+
DirectiveSet dirs_;
14551423
};
14561424

14571425
struct OmpStatementConstructParser {
@@ -1946,11 +1914,56 @@ TYPE_CONTEXT_PARSER("OpenMP construct"_en_US,
19461914
construct<OpenMPConstruct>(Parser<OpenMPAssumeConstruct>{}),
19471915
construct<OpenMPConstruct>(Parser<OpenMPCriticalConstruct>{}))))
19481916

1917+
static constexpr DirectiveSet GetLoopDirectives() {
1918+
using Directive = llvm::omp::Directive;
1919+
constexpr DirectiveSet loopDirectives{
1920+
unsigned(Directive::OMPD_distribute),
1921+
unsigned(Directive::OMPD_distribute_parallel_do),
1922+
unsigned(Directive::OMPD_distribute_parallel_do_simd),
1923+
unsigned(Directive::OMPD_distribute_simd),
1924+
unsigned(Directive::OMPD_do),
1925+
unsigned(Directive::OMPD_do_simd),
1926+
unsigned(Directive::OMPD_loop),
1927+
unsigned(Directive::OMPD_masked_taskloop),
1928+
unsigned(Directive::OMPD_masked_taskloop_simd),
1929+
unsigned(Directive::OMPD_master_taskloop),
1930+
unsigned(Directive::OMPD_master_taskloop_simd),
1931+
unsigned(Directive::OMPD_parallel_do),
1932+
unsigned(Directive::OMPD_parallel_do_simd),
1933+
unsigned(Directive::OMPD_parallel_masked_taskloop),
1934+
unsigned(Directive::OMPD_parallel_masked_taskloop_simd),
1935+
unsigned(Directive::OMPD_parallel_master_taskloop),
1936+
unsigned(Directive::OMPD_parallel_master_taskloop_simd),
1937+
unsigned(Directive::OMPD_simd),
1938+
unsigned(Directive::OMPD_target_loop),
1939+
unsigned(Directive::OMPD_target_parallel_do),
1940+
unsigned(Directive::OMPD_target_parallel_do_simd),
1941+
unsigned(Directive::OMPD_target_parallel_loop),
1942+
unsigned(Directive::OMPD_target_simd),
1943+
unsigned(Directive::OMPD_target_teams_distribute),
1944+
unsigned(Directive::OMPD_target_teams_distribute_parallel_do),
1945+
unsigned(Directive::OMPD_target_teams_distribute_parallel_do_simd),
1946+
unsigned(Directive::OMPD_target_teams_distribute_simd),
1947+
unsigned(Directive::OMPD_target_teams_loop),
1948+
unsigned(Directive::OMPD_taskloop),
1949+
unsigned(Directive::OMPD_taskloop_simd),
1950+
unsigned(Directive::OMPD_teams_distribute),
1951+
unsigned(Directive::OMPD_teams_distribute_parallel_do),
1952+
unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
1953+
unsigned(Directive::OMPD_teams_distribute_simd),
1954+
unsigned(Directive::OMPD_teams_loop),
1955+
unsigned(Directive::OMPD_tile),
1956+
unsigned(Directive::OMPD_unroll),
1957+
};
1958+
return loopDirectives;
1959+
}
1960+
1961+
TYPE_PARSER(sourced(construct<OmpBeginLoopDirective>(
1962+
sourced(OmpBeginDirectiveParser(GetLoopDirectives())))))
1963+
19491964
// END OMP Loop directives
1950-
TYPE_PARSER(
1951-
startOmpLine >> sourced(construct<OmpEndLoopDirective>(
1952-
sourced("END"_tok >> Parser<OmpLoopDirective>{}),
1953-
Parser<OmpClauseList>{})))
1965+
TYPE_PARSER(sourced(construct<OmpEndLoopDirective>(
1966+
sourced(OmpEndDirectiveParser(GetLoopDirectives())))))
19541967

19551968
TYPE_PARSER(construct<OpenMPLoopConstruct>(
19561969
Parser<OmpBeginLoopDirective>{} / endOmpLine))

0 commit comments

Comments
 (0)