Skip to content

Commit 0599bff

Browse files
committed
[flang][OpenMP] Use OmpDirectiveSpecification in Omp[Begin|End]LoopDirective
This makes accessing directive components, such as directive name or the list of clauses simpler and more uniform across different directives. It also makes the parser simpler, since it reuses existing parsing functionality.
1 parent c366cbd commit 0599bff

37 files changed

+351
-478
lines changed

flang/include/flang/Parser/openmp-utils.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ struct DirectiveNameScope {
6767
}
6868

6969
static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
70-
auto &dir{std::get<OmpLoopDirective>(x.t)};
71-
return MakeName(dir.source, dir.v);
70+
return x.DirName();
7271
}
7372

7473
static OmpDirectiveName GetOmpDirectiveName(const OpenMPSectionConstruct &x) {

flang/include/flang/Parser/parse-tree.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5158,16 +5158,12 @@ struct OpenMPStandaloneConstruct {
51585158
u;
51595159
};
51605160

5161-
struct OmpBeginLoopDirective {
5162-
TUPLE_CLASS_BOILERPLATE(OmpBeginLoopDirective);
5163-
std::tuple<OmpLoopDirective, OmpClauseList> t;
5164-
CharBlock source;
5161+
struct OmpBeginLoopDirective : public OmpBeginDirective {
5162+
INHERITED_TUPLE_CLASS_BOILERPLATE(OmpBeginLoopDirective, OmpBeginDirective);
51655163
};
51665164

5167-
struct OmpEndLoopDirective {
5168-
TUPLE_CLASS_BOILERPLATE(OmpEndLoopDirective);
5169-
std::tuple<OmpLoopDirective, OmpClauseList> t;
5170-
CharBlock source;
5165+
struct OmpEndLoopDirective : public OmpEndDirective {
5166+
INHERITED_TUPLE_CLASS_BOILERPLATE(OmpEndLoopDirective, OmpEndDirective);
51715167
};
51725168

51735169
// OpenMP directives enclosing do loop
@@ -5177,6 +5173,13 @@ struct OpenMPLoopConstruct {
51775173
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
51785174
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
51795175
: t({std::move(a), std::nullopt, std::nullopt}) {}
5176+
5177+
const OmpBeginLoopDirective &BeginDir() const {
5178+
return std::get<OmpBeginLoopDirective>(t);
5179+
}
5180+
const std::optional<OmpEndLoopDirective> &EndDir() const {
5181+
return std::get<std::optional<OmpEndLoopDirective>>(t);
5182+
}
51805183
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
51815184
std::optional<OmpEndLoopDirective>>
51825185
t;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -408,26 +408,15 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
408408
const parser::OmpClauseList *beginClauseList = nullptr;
409409
const parser::OmpClauseList *endClauseList = nullptr;
410410
common::visit(
411-
common::visitors{
412-
[&](const parser::OmpBlockConstruct &ompConstruct) {
413-
beginClauseList = &ompConstruct.BeginDir().Clauses();
414-
if (auto &endSpec = ompConstruct.EndDir())
415-
endClauseList = &endSpec->Clauses();
416-
},
417-
[&](const parser::OpenMPLoopConstruct &ompConstruct) {
418-
const auto &beginDirective =
419-
std::get<parser::OmpBeginLoopDirective>(ompConstruct.t);
420-
beginClauseList =
421-
&std::get<parser::OmpClauseList>(beginDirective.t);
422-
423-
if (auto &endDirective =
424-
std::get<std::optional<parser::OmpEndLoopDirective>>(
425-
ompConstruct.t)) {
426-
endClauseList =
427-
&std::get<parser::OmpClauseList>(endDirective->t);
428-
}
429-
},
430-
[&](const auto &) {}},
411+
[&](const auto &construct) {
412+
using Type = llvm::remove_cvref_t<decltype(construct)>;
413+
if constexpr (std::is_same_v<Type, parser::OmpBlockConstruct> ||
414+
std::is_same_v<Type, parser::OpenMPLoopConstruct>) {
415+
beginClauseList = &construct.BeginDir().Clauses();
416+
if (auto &endSpec = construct.EndDir())
417+
endClauseList = &endSpec->Clauses();
418+
}
419+
},
431420
ompEval->u);
432421

433422
assert(beginClauseList && "expected begin directive");
@@ -3820,19 +3809,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
38203809
semantics::SemanticsContext &semaCtx,
38213810
lower::pft::Evaluation &eval,
38223811
const parser::OpenMPLoopConstruct &loopConstruct) {
3823-
const auto &beginLoopDirective =
3824-
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t);
3825-
List<Clause> clauses = makeClauses(
3826-
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
3827-
if (auto &endLoopDirective =
3828-
std::get<std::optional<parser::OmpEndLoopDirective>>(
3829-
loopConstruct.t)) {
3830-
clauses.append(makeClauses(
3831-
std::get<parser::OmpClauseList>(endLoopDirective->t), semaCtx));
3832-
}
3812+
const parser::OmpDirectiveSpecification &beginSpec = loopConstruct.BeginDir();
3813+
List<Clause> clauses = makeClauses(beginSpec.Clauses(), semaCtx);
3814+
if (auto &endSpec = loopConstruct.EndDir())
3815+
clauses.append(makeClauses(endSpec->Clauses(), semaCtx));
38333816

3834-
mlir::Location currentLocation =
3835-
converter.genLocation(beginLoopDirective.source);
3817+
mlir::Location currentLocation = converter.genLocation(beginSpec.source);
38363818

38373819
auto &optLoopCons =
38383820
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
@@ -3858,13 +3840,10 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
38583840
}
38593841
}
38603842

3861-
llvm::omp::Directive directive =
3862-
parser::omp::GetOmpDirectiveName(beginLoopDirective).v;
3863-
const parser::CharBlock &source =
3864-
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source;
3843+
const parser::OmpDirectiveName &beginName = beginSpec.DirName();
38653844
ConstructQueue queue{
38663845
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
3867-
eval, source, directive, clauses)};
3846+
eval, beginName.source, beginName.v, clauses)};
38683847
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
38693848
queue.begin());
38703849
}
@@ -4047,8 +4026,7 @@ bool Fortran::lower::isOpenMPTargetConstruct(
40474026
dir = block->BeginDir().DirId();
40484027
} else if (const auto *loop =
40494028
std::get_if<parser::OpenMPLoopConstruct>(&omp.u)) {
4050-
const auto &begin = std::get<parser::OmpBeginLoopDirective>(loop->t);
4051-
dir = std::get<parser::OmpLoopDirective>(begin.t).v;
4029+
dir = loop->BeginDir().DirId();
40524030
}
40534031
return llvm::omp::allTargetSet.test(dir);
40544032
}

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -616,16 +616,11 @@ static void processTileSizesFromOpenMPConstruct(
616616
&(nestedOptional.value()));
617617
if (innerConstruct) {
618618
const auto &innerLoopDirective = innerConstruct->value();
619-
const auto &innerBegin =
620-
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
621-
const auto &innerDirective =
622-
std::get<parser::OmpLoopDirective>(innerBegin.t).v;
623-
624-
if (innerDirective == llvm::omp::Directive::OMPD_tile) {
619+
const parser::OmpDirectiveSpecification &innerBeginSpec =
620+
innerLoopDirective.BeginDir();
621+
if (innerBeginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
625622
// Get the size values from parse tree and convert to a vector.
626-
const auto &innerClauseList{
627-
std::get<parser::OmpClauseList>(innerBegin.t)};
628-
for (const auto &clause : innerClauseList.v) {
623+
for (const auto &clause : innerBeginSpec.Clauses().v) {
629624
if (const auto tclause{
630625
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
631626
processFun(tclause);

flang/lib/Parser/openmp-parsers.cpp

Lines changed: 86 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,34 @@
1818
#include "flang/Parser/openmp-utils.h"
1919
#include "flang/Parser/parse-tree.h"
2020
#include "llvm/ADT/ArrayRef.h"
21+
#include "llvm/ADT/Bitset.h"
2122
#include "llvm/ADT/STLExtras.h"
2223
#include "llvm/ADT/StringRef.h"
2324
#include "llvm/ADT/StringSet.h"
2425
#include "llvm/Frontend/OpenMP/OMP.h"
26+
#include "llvm/Support/MathExtras.h"
27+
28+
#include <algorithm>
29+
#include <cctype>
30+
#include <iterator>
31+
#include <list>
32+
#include <optional>
33+
#include <string>
34+
#include <tuple>
35+
#include <type_traits>
36+
#include <utility>
37+
#include <variant>
38+
#include <vector>
2539

2640
// OpenMP Directives and Clauses
2741
namespace Fortran::parser {
2842
using namespace Fortran::parser::omp;
2943

44+
static constexpr size_t DirectiveCount{
45+
static_cast<size_t>(llvm::omp::Directive::Last_) -
46+
static_cast<size_t>(llvm::omp::Directive::First_) + 1};
47+
using DirectiveSet = llvm::Bitset<llvm::NextPowerOf2(DirectiveCount)>;
48+
3049
// Helper function to print the buffer contents starting at the current point.
3150
[[maybe_unused]] static std::string ahead(const ParseState &state) {
3251
return std::string(
@@ -1349,95 +1368,46 @@ TYPE_PARSER(sourced(construct<OpenMPUtilityConstruct>(
13491368
TYPE_PARSER(sourced(construct<OmpMetadirectiveDirective>(
13501369
verbatim("METADIRECTIVE"_tok), Parser<OmpClauseList>{})))
13511370

1352-
// Omp directives enclosing do loop
1353-
TYPE_PARSER(sourced(construct<OmpLoopDirective>(first(
1354-
"DISTRIBUTE PARALLEL DO SIMD" >>
1355-
pure(llvm::omp::Directive::OMPD_distribute_parallel_do_simd),
1356-
"DISTRIBUTE PARALLEL DO" >>
1357-
pure(llvm::omp::Directive::OMPD_distribute_parallel_do),
1358-
"DISTRIBUTE SIMD" >> pure(llvm::omp::Directive::OMPD_distribute_simd),
1359-
"DISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_distribute),
1360-
"DO SIMD" >> pure(llvm::omp::Directive::OMPD_do_simd),
1361-
"DO" >> pure(llvm::omp::Directive::OMPD_do),
1362-
"LOOP" >> pure(llvm::omp::Directive::OMPD_loop),
1363-
"MASKED TASKLOOP SIMD" >>
1364-
pure(llvm::omp::Directive::OMPD_masked_taskloop_simd),
1365-
"MASKED TASKLOOP" >> pure(llvm::omp::Directive::OMPD_masked_taskloop),
1366-
"MASTER TASKLOOP SIMD" >>
1367-
pure(llvm::omp::Directive::OMPD_master_taskloop_simd),
1368-
"MASTER TASKLOOP" >> pure(llvm::omp::Directive::OMPD_master_taskloop),
1369-
"PARALLEL DO SIMD" >> pure(llvm::omp::Directive::OMPD_parallel_do_simd),
1370-
"PARALLEL DO" >> pure(llvm::omp::Directive::OMPD_parallel_do),
1371-
"PARALLEL MASKED TASKLOOP SIMD" >>
1372-
pure(llvm::omp::Directive::OMPD_parallel_masked_taskloop_simd),
1373-
"PARALLEL MASKED TASKLOOP" >>
1374-
pure(llvm::omp::Directive::OMPD_parallel_masked_taskloop),
1375-
"PARALLEL MASTER TASKLOOP SIMD" >>
1376-
pure(llvm::omp::Directive::OMPD_parallel_master_taskloop_simd),
1377-
"PARALLEL MASTER TASKLOOP" >>
1378-
pure(llvm::omp::Directive::OMPD_parallel_master_taskloop),
1379-
"SIMD" >> pure(llvm::omp::Directive::OMPD_simd),
1380-
"TARGET LOOP" >> pure(llvm::omp::Directive::OMPD_target_loop),
1381-
"TARGET PARALLEL DO SIMD" >>
1382-
pure(llvm::omp::Directive::OMPD_target_parallel_do_simd),
1383-
"TARGET PARALLEL DO" >> pure(llvm::omp::Directive::OMPD_target_parallel_do),
1384-
"TARGET PARALLEL LOOP" >>
1385-
pure(llvm::omp::Directive::OMPD_target_parallel_loop),
1386-
"TARGET SIMD" >> pure(llvm::omp::Directive::OMPD_target_simd),
1387-
"TARGET TEAMS DISTRIBUTE PARALLEL DO SIMD" >>
1388-
pure(llvm::omp::Directive::
1389-
OMPD_target_teams_distribute_parallel_do_simd),
1390-
"TARGET TEAMS DISTRIBUTE PARALLEL DO" >>
1391-
pure(llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do),
1392-
"TARGET TEAMS DISTRIBUTE SIMD" >>
1393-
pure(llvm::omp::Directive::OMPD_target_teams_distribute_simd),
1394-
"TARGET TEAMS DISTRIBUTE" >>
1395-
pure(llvm::omp::Directive::OMPD_target_teams_distribute),
1396-
"TARGET TEAMS LOOP" >> pure(llvm::omp::Directive::OMPD_target_teams_loop),
1397-
"TASKLOOP SIMD" >> pure(llvm::omp::Directive::OMPD_taskloop_simd),
1398-
"TASKLOOP" >> pure(llvm::omp::Directive::OMPD_taskloop),
1399-
"TEAMS DISTRIBUTE PARALLEL DO SIMD" >>
1400-
pure(llvm::omp::Directive::OMPD_teams_distribute_parallel_do_simd),
1401-
"TEAMS DISTRIBUTE PARALLEL DO" >>
1402-
pure(llvm::omp::Directive::OMPD_teams_distribute_parallel_do),
1403-
"TEAMS DISTRIBUTE SIMD" >>
1404-
pure(llvm::omp::Directive::OMPD_teams_distribute_simd),
1405-
"TEAMS DISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_teams_distribute),
1406-
"TEAMS LOOP" >> pure(llvm::omp::Directive::OMPD_teams_loop),
1407-
"TILE" >> pure(llvm::omp::Directive::OMPD_tile),
1408-
"UNROLL" >> pure(llvm::omp::Directive::OMPD_unroll)))))
1409-
1410-
TYPE_PARSER(sourced(construct<OmpBeginLoopDirective>(
1411-
sourced(Parser<OmpLoopDirective>{}), Parser<OmpClauseList>{})))
1412-
14131371
static inline constexpr auto IsDirective(llvm::omp::Directive dir) {
14141372
return [dir](const OmpDirectiveName &name) -> bool { return dir == name.v; };
14151373
}
14161374

1375+
static inline constexpr auto IsMemberOf(const DirectiveSet &dirs) {
1376+
return [&dirs](const OmpDirectiveName &name) -> bool {
1377+
return dirs.test(llvm::to_underlying(name.v));
1378+
};
1379+
}
1380+
14171381
struct OmpBeginDirectiveParser {
14181382
using resultType = OmpDirectiveSpecification;
14191383

1420-
constexpr OmpBeginDirectiveParser(llvm::omp::Directive dir) : dir_(dir) {}
1384+
constexpr OmpBeginDirectiveParser(DirectiveSet dirs) : dirs_(dirs) {}
1385+
constexpr OmpBeginDirectiveParser(llvm::omp::Directive dir) {
1386+
dirs_.set(llvm::to_underlying(dir));
1387+
}
14211388

14221389
std::optional<resultType> Parse(ParseState &state) const {
1423-
auto &&p{predicated(Parser<OmpDirectiveName>{}, IsDirective(dir_)) >=
1390+
auto &&p{predicated(Parser<OmpDirectiveName>{}, IsMemberOf(dirs_)) >=
14241391
Parser<OmpDirectiveSpecification>{}};
14251392
return p.Parse(state);
14261393
}
14271394

14281395
private:
1429-
llvm::omp::Directive dir_;
1396+
DirectiveSet dirs_;
14301397
};
14311398

14321399
struct OmpEndDirectiveParser {
14331400
using resultType = OmpDirectiveSpecification;
14341401

1435-
constexpr OmpEndDirectiveParser(llvm::omp::Directive dir) : dir_(dir) {}
1402+
constexpr OmpEndDirectiveParser(DirectiveSet dirs) : dirs_(dirs) {}
1403+
constexpr OmpEndDirectiveParser(llvm::omp::Directive dir) {
1404+
dirs_.set(llvm::to_underlying(dir));
1405+
}
14361406

14371407
std::optional<resultType> Parse(ParseState &state) const {
14381408
if (startOmpLine.Parse(state)) {
14391409
if (auto endToken{verbatim("END"_sptok).Parse(state)}) {
1440-
if (auto &&dirSpec{OmpBeginDirectiveParser(dir_).Parse(state)}) {
1410+
if (auto &&dirSpec{OmpBeginDirectiveParser(dirs_).Parse(state)}) {
14411411
// Extend the "source" on both the OmpDirectiveName and the
14421412
// OmpDirectiveNameSpecification.
14431413
CharBlock &nameSource{std::get<OmpDirectiveName>(dirSpec->t).source};
@@ -1451,7 +1421,7 @@ struct OmpEndDirectiveParser {
14511421
}
14521422

14531423
private:
1454-
llvm::omp::Directive dir_;
1424+
DirectiveSet dirs_;
14551425
};
14561426

14571427
struct OmpStatementConstructParser {
@@ -1946,11 +1916,56 @@ TYPE_CONTEXT_PARSER("OpenMP construct"_en_US,
19461916
construct<OpenMPConstruct>(Parser<OpenMPAssumeConstruct>{}),
19471917
construct<OpenMPConstruct>(Parser<OpenMPCriticalConstruct>{}))))
19481918

1919+
static constexpr DirectiveSet GetLoopDirectives() {
1920+
using Directive = llvm::omp::Directive;
1921+
constexpr DirectiveSet loopDirectives{
1922+
unsigned(Directive::OMPD_distribute),
1923+
unsigned(Directive::OMPD_distribute_parallel_do),
1924+
unsigned(Directive::OMPD_distribute_parallel_do_simd),
1925+
unsigned(Directive::OMPD_distribute_simd),
1926+
unsigned(Directive::OMPD_do),
1927+
unsigned(Directive::OMPD_do_simd),
1928+
unsigned(Directive::OMPD_loop),
1929+
unsigned(Directive::OMPD_masked_taskloop),
1930+
unsigned(Directive::OMPD_masked_taskloop_simd),
1931+
unsigned(Directive::OMPD_master_taskloop),
1932+
unsigned(Directive::OMPD_master_taskloop_simd),
1933+
unsigned(Directive::OMPD_parallel_do),
1934+
unsigned(Directive::OMPD_parallel_do_simd),
1935+
unsigned(Directive::OMPD_parallel_masked_taskloop),
1936+
unsigned(Directive::OMPD_parallel_masked_taskloop_simd),
1937+
unsigned(Directive::OMPD_parallel_master_taskloop),
1938+
unsigned(Directive::OMPD_parallel_master_taskloop_simd),
1939+
unsigned(Directive::OMPD_simd),
1940+
unsigned(Directive::OMPD_target_loop),
1941+
unsigned(Directive::OMPD_target_parallel_do),
1942+
unsigned(Directive::OMPD_target_parallel_do_simd),
1943+
unsigned(Directive::OMPD_target_parallel_loop),
1944+
unsigned(Directive::OMPD_target_simd),
1945+
unsigned(Directive::OMPD_target_teams_distribute),
1946+
unsigned(Directive::OMPD_target_teams_distribute_parallel_do),
1947+
unsigned(Directive::OMPD_target_teams_distribute_parallel_do_simd),
1948+
unsigned(Directive::OMPD_target_teams_distribute_simd),
1949+
unsigned(Directive::OMPD_target_teams_loop),
1950+
unsigned(Directive::OMPD_taskloop),
1951+
unsigned(Directive::OMPD_taskloop_simd),
1952+
unsigned(Directive::OMPD_teams_distribute),
1953+
unsigned(Directive::OMPD_teams_distribute_parallel_do),
1954+
unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
1955+
unsigned(Directive::OMPD_teams_distribute_simd),
1956+
unsigned(Directive::OMPD_teams_loop),
1957+
unsigned(Directive::OMPD_tile),
1958+
unsigned(Directive::OMPD_unroll),
1959+
};
1960+
return loopDirectives;
1961+
}
1962+
1963+
TYPE_PARSER(sourced(construct<OmpBeginLoopDirective>(
1964+
sourced(OmpBeginDirectiveParser(GetLoopDirectives())))))
1965+
19491966
// END OMP Loop directives
1950-
TYPE_PARSER(
1951-
startOmpLine >> sourced(construct<OmpEndLoopDirective>(
1952-
sourced("END"_tok >> Parser<OmpLoopDirective>{}),
1953-
Parser<OmpClauseList>{})))
1967+
TYPE_PARSER(sourced(construct<OmpEndLoopDirective>(
1968+
sourced(OmpEndDirectiveParser(GetLoopDirectives())))))
19541969

19551970
TYPE_PARSER(construct<OpenMPLoopConstruct>(
19561971
Parser<OmpBeginLoopDirective>{} / endOmpLine))

0 commit comments

Comments
 (0)