Skip to content

Commit c2f8d80

Browse files
kparzyszkcloudy0717
authored andcommitted
[flang][OpenMP] Expand GetOmpObjectList to all subclasses of OmpClause (llvm#170351)
Use GetOmpObjectList instead of extracting the object list by hand.
1 parent af39d95 commit c2f8d80

File tree

5 files changed

+143
-97
lines changed

5 files changed

+143
-97
lines changed

flang/include/flang/Parser/openmp-utils.h

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define FORTRAN_PARSER_OPENMP_UTILS_H
1515

1616
#include "flang/Common/indirection.h"
17+
#include "flang/Common/template.h"
1718
#include "flang/Parser/parse-tree.h"
1819
#include "llvm/Frontend/OpenMP/OMP.h"
1920

@@ -127,7 +128,88 @@ template <typename T> struct IsStatement<Statement<T>> {
127128
std::optional<Label> GetStatementLabel(const ExecutionPartConstruct &x);
128129
std::optional<Label> GetFinalLabel(const OpenMPConstruct &x);
129130

131+
namespace detail {
132+
// Clauses with flangClass = "OmpObjectList".
133+
using MemberObjectListClauses =
134+
std::tuple<OmpClause::Copyin, OmpClause::Copyprivate, OmpClause::Exclusive,
135+
OmpClause::Firstprivate, OmpClause::HasDeviceAddr, OmpClause::Inclusive,
136+
OmpClause::IsDevicePtr, OmpClause::Link, OmpClause::Private,
137+
OmpClause::Shared, OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
138+
139+
// Clauses with flangClass = "OmpSomeClause", and OmpObjectList a
140+
// member of tuple OmpSomeClause::t.
141+
using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
142+
OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
143+
OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
144+
OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
145+
OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
146+
147+
// Does U have WrapperTrait (i.e. has a member 'v'), and if so, is T the
148+
// type of v?
149+
template <typename T, typename U, bool IsWrapper> struct WrappedInType {
150+
static constexpr bool value{false};
151+
};
152+
153+
template <typename T, typename U> struct WrappedInType<T, U, true> {
154+
static constexpr bool value{std::is_same_v<T, decltype(U::v)>};
155+
};
156+
157+
// Same as WrappedInType, but with a list of types Us. Satisfied if any
158+
// type U in Us satisfies WrappedInType<T, U>.
159+
template <typename...> struct WrappedInTypes;
160+
161+
template <typename T> struct WrappedInTypes<T> {
162+
static constexpr bool value{false};
163+
};
164+
165+
template <typename T, typename U, typename... Us>
166+
struct WrappedInTypes<T, U, Us...> {
167+
static constexpr bool value{WrappedInType<T, U, WrapperTrait<U>>::value ||
168+
WrappedInTypes<T, Us...>::value};
169+
};
170+
171+
// Same as WrappedInTypes, but takes type list in a form of a tuple or
172+
// a variant.
173+
template <typename...> struct WrappedInTupleOrVariant {
174+
static constexpr bool value{false};
175+
};
176+
template <typename T, typename... Us>
177+
struct WrappedInTupleOrVariant<T, std::tuple<Us...>> {
178+
static constexpr bool value{WrappedInTypes<T, Us...>::value};
179+
};
180+
template <typename T, typename... Us>
181+
struct WrappedInTupleOrVariant<T, std::variant<Us...>> {
182+
static constexpr bool value{WrappedInTypes<T, Us...>::value};
183+
};
184+
template <typename T, typename U>
185+
constexpr bool WrappedInTupleOrVariantV{WrappedInTupleOrVariant<T, U>::value};
186+
} // namespace detail
187+
188+
template <typename T> const OmpObjectList *GetOmpObjectList(const T &clause) {
189+
using namespace detail;
190+
static_assert(std::is_class_v<T>, "Unexpected argument type");
191+
192+
if constexpr (common::HasMember<T, decltype(OmpClause::u)>) {
193+
if constexpr (common::HasMember<T, MemberObjectListClauses>) {
194+
return &clause.v;
195+
} else if constexpr (common::HasMember<T, TupleObjectListClauses>) {
196+
return &std::get<OmpObjectList>(clause.v.t);
197+
} else {
198+
return nullptr;
199+
}
200+
} else if constexpr (WrappedInTupleOrVariantV<T, TupleObjectListClauses>) {
201+
return &std::get<OmpObjectList>(clause.t);
202+
} else if constexpr (WrappedInTupleOrVariantV<T, decltype(OmpClause::u)>) {
203+
return nullptr;
204+
} else {
205+
// The condition should be type-dependent, but it should always be false.
206+
static_assert(sizeof(T) < 0 && "Unexpected argument type");
207+
}
208+
}
209+
130210
const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
211+
const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause);
212+
const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x);
131213

132214
template <typename T>
133215
const T *GetFirstArgument(const OmpDirectiveSpecification &spec) {

flang/lib/Parser/openmp-utils.cpp

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -117,43 +117,20 @@ std::optional<Label> GetFinalLabel(const OpenMPConstruct &x) {
117117
}
118118

119119
const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
120-
// Clauses with OmpObjectList as its data member
121-
using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
122-
OmpClause::Copyprivate, OmpClause::Exclusive, OmpClause::Firstprivate,
123-
OmpClause::HasDeviceAddr, OmpClause::Inclusive, OmpClause::IsDevicePtr,
124-
OmpClause::Link, OmpClause::Private, OmpClause::Shared,
125-
OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
126-
127-
// Clauses with OmpObjectList in the tuple
128-
using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
129-
OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
130-
OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
131-
OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
132-
OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
133-
134-
// TODO:: Generate the tuples using TableGen.
120+
return common::visit([](auto &&s) { return GetOmpObjectList(s); }, clause.u);
121+
}
122+
123+
const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause) {
135124
return common::visit(
136125
common::visitors{
137-
[&](const OmpClause::Depend &x) -> const OmpObjectList * {
138-
if (auto *taskDep{std::get_if<OmpDependClause::TaskDep>(&x.v.u)}) {
139-
return &std::get<OmpObjectList>(taskDep->t);
140-
} else {
141-
return nullptr;
142-
}
143-
},
144-
[&](const auto &x) -> const OmpObjectList * {
145-
using Ty = std::decay_t<decltype(x)>;
146-
if constexpr (common::HasMember<Ty, MemberObjectListClauses>) {
147-
return &x.v;
148-
} else if constexpr (common::HasMember<Ty,
149-
TupleObjectListClauses>) {
150-
return &std::get<OmpObjectList>(x.v.t);
151-
} else {
152-
return nullptr;
153-
}
154-
},
126+
[](const OmpDoacross &) { return nullptr; },
127+
[](const OmpDependClause::TaskDep &x) { return GetOmpObjectList(x); },
155128
},
156-
clause.u);
129+
clause.v.u);
130+
}
131+
132+
const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x) {
133+
return &std::get<OmpObjectList>(x.t);
157134
}
158135

159136
const BlockConstruct *GetFortranBlockConstruct(

flang/lib/Semantics/check-omp-loop.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -480,9 +480,8 @@ void OmpStructureChecker::CheckDistLinear(
480480

481481
// Collect symbols of all the variables from linear clauses
482482
for (auto &clause : clauses.v) {
483-
if (auto *linearClause{std::get_if<parser::OmpClause::Linear>(&clause.u)}) {
484-
auto &objects{std::get<parser::OmpObjectList>(linearClause->v.t)};
485-
GetSymbolsInObjectList(objects, indexVars);
483+
if (std::get_if<parser::OmpClause::Linear>(&clause.u)) {
484+
GetSymbolsInObjectList(*parser::omp::GetOmpObjectList(clause), indexVars);
486485
}
487486
}
488487

@@ -604,8 +603,6 @@ void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
604603
auto *maybeModifier{OmpGetUniqueModifier<ReductionModifier>(modifiers)};
605604
if (maybeModifier &&
606605
maybeModifier->v == ReductionModifier::Value::Inscan) {
607-
const auto &objectList{
608-
std::get<parser::OmpObjectList>(reductionClause->v.t)};
609606
auto checkReductionSymbolInScan = [&](const parser::Name *name) {
610607
if (auto &symbol = name->symbol) {
611608
if (!symbol->test(Symbol::Flag::OmpInclusiveScan) &&
@@ -618,7 +615,7 @@ void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
618615
}
619616
}
620617
};
621-
for (const auto &ompObj : objectList.v) {
618+
for (const auto &ompObj : parser::omp::GetOmpObjectList(clause)->v) {
622619
common::visit(
623620
common::visitors{
624621
[&](const parser::Designator &designator) {

flang/lib/Semantics/check-omp-structure.cpp

Lines changed: 33 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -624,11 +624,9 @@ void OmpStructureChecker::CheckMultListItems() {
624624

625625
// Linear clause
626626
for (auto [_, clause] : FindClauses(llvm::omp::Clause::OMPC_linear)) {
627-
auto &linearClause{std::get<parser::OmpClause::Linear>(clause->u)};
628627
std::list<parser::Name> nameList;
629628
SymbolSourceMap symbols;
630-
GetSymbolsInObjectList(
631-
std::get<parser::OmpObjectList>(linearClause.v.t), symbols);
629+
GetSymbolsInObjectList(*GetOmpObjectList(*clause), symbols);
632630
llvm::transform(symbols, std::back_inserter(nameList), [&](auto &&pair) {
633631
return parser::Name{pair.second, const_cast<Symbol *>(pair.first)};
634632
});
@@ -2101,29 +2099,29 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclareTargetConstruct &x) {
21012099
}
21022100
}
21032101

2104-
bool toClauseFound{false}, deviceTypeClauseFound{false},
2105-
enterClauseFound{false};
2102+
bool toClauseFound{false};
2103+
bool deviceTypeClauseFound{false};
2104+
bool enterClauseFound{false};
21062105
for (const parser::OmpClause &clause : x.v.Clauses().v) {
21072106
common::visit(
21082107
common::visitors{
2109-
[&](const parser::OmpClause::To &toClause) {
2110-
toClauseFound = true;
2111-
auto &objList{std::get<parser::OmpObjectList>(toClause.v.t)};
2112-
CheckSymbolNames(dirName.source, objList);
2113-
CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
2114-
CheckThreadprivateOrDeclareTargetVar(objList);
2115-
},
2116-
[&](const parser::OmpClause::Link &linkClause) {
2117-
CheckSymbolNames(dirName.source, linkClause.v);
2118-
CheckVarIsNotPartOfAnotherVar(dirName.source, linkClause.v);
2119-
CheckThreadprivateOrDeclareTargetVar(linkClause.v);
2120-
},
2121-
[&](const parser::OmpClause::Enter &enterClause) {
2122-
enterClauseFound = true;
2123-
auto &objList{std::get<parser::OmpObjectList>(enterClause.v.t)};
2124-
CheckSymbolNames(dirName.source, objList);
2125-
CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
2126-
CheckThreadprivateOrDeclareTargetVar(objList);
2108+
[&](const auto &c) {
2109+
using TypeC = llvm::remove_cvref_t<decltype(c)>;
2110+
if constexpr ( //
2111+
std::is_same_v<TypeC, parser::OmpClause::Enter> ||
2112+
std::is_same_v<TypeC, parser::OmpClause::Link> ||
2113+
std::is_same_v<TypeC, parser::OmpClause::To>) {
2114+
auto &objList{*GetOmpObjectList(c)};
2115+
CheckSymbolNames(dirName.source, objList);
2116+
CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
2117+
CheckThreadprivateOrDeclareTargetVar(objList);
2118+
}
2119+
if constexpr (std::is_same_v<TypeC, parser::OmpClause::Enter>) {
2120+
enterClauseFound = true;
2121+
}
2122+
if constexpr (std::is_same_v<TypeC, parser::OmpClause::To>) {
2123+
toClauseFound = true;
2124+
}
21272125
},
21282126
[&](const parser::OmpClause::DeviceType &deviceTypeClause) {
21292127
deviceTypeClauseFound = true;
@@ -2134,7 +2132,6 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclareTargetConstruct &x) {
21342132
deviceConstructFound_ = true;
21352133
}
21362134
},
2137-
[&](const auto &) {},
21382135
},
21392136
clause.u);
21402137

@@ -2424,12 +2421,8 @@ void OmpStructureChecker::CheckTargetUpdate() {
24242421
}
24252422
if (toWrapper && fromWrapper) {
24262423
SymbolSourceMap toSymbols, fromSymbols;
2427-
auto &fromClause{std::get<parser::OmpClause::From>(fromWrapper->u).v};
2428-
auto &toClause{std::get<parser::OmpClause::To>(toWrapper->u).v};
2429-
GetSymbolsInObjectList(
2430-
std::get<parser::OmpObjectList>(fromClause.t), fromSymbols);
2431-
GetSymbolsInObjectList(
2432-
std::get<parser::OmpObjectList>(toClause.t), toSymbols);
2424+
GetSymbolsInObjectList(*GetOmpObjectList(*fromWrapper), fromSymbols);
2425+
GetSymbolsInObjectList(*GetOmpObjectList(*toWrapper), toSymbols);
24332426

24342427
for (auto &[symbol, source] : toSymbols) {
24352428
auto fromSymbol{fromSymbols.find(symbol)};
@@ -3269,7 +3262,7 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) {
32693262
const auto &irClause{
32703263
std::get<parser::OmpClause::InReduction>(dataEnvClause->u)};
32713264
checkVarAppearsInDataEnvClause(
3272-
std::get<parser::OmpObjectList>(irClause.v.t), "IN_REDUCTION");
3265+
*GetOmpObjectList(irClause), "IN_REDUCTION");
32733266
}
32743267
}
32753268
}
@@ -3421,7 +3414,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Destroy &x) {
34213414

34223415
void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
34233416
CheckAllowedClause(llvm::omp::Clause::OMPC_reduction);
3424-
auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
3417+
auto &objects{*GetOmpObjectList(x)};
34253418

34263419
if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_reduction,
34273420
GetContext().clauseSource, context_)) {
@@ -3461,7 +3454,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
34613454

34623455
void OmpStructureChecker::Enter(const parser::OmpClause::InReduction &x) {
34633456
CheckAllowedClause(llvm::omp::Clause::OMPC_in_reduction);
3464-
auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
3457+
auto &objects{*GetOmpObjectList(x)};
34653458

34663459
if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_in_reduction,
34673460
GetContext().clauseSource, context_)) {
@@ -3479,7 +3472,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::InReduction &x) {
34793472

34803473
void OmpStructureChecker::Enter(const parser::OmpClause::TaskReduction &x) {
34813474
CheckAllowedClause(llvm::omp::Clause::OMPC_task_reduction);
3482-
auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
3475+
auto &objects{*GetOmpObjectList(x)};
34833476

34843477
if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_task_reduction,
34853478
GetContext().clauseSource, context_)) {
@@ -4332,8 +4325,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) {
43324325
}};
43334326

43344327
evaluate::ExpressionAnalyzer ea{context_};
4335-
const auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
4336-
for (auto &object : objects.v) {
4328+
for (auto &object : GetOmpObjectList(x)->v) {
43374329
if (const parser::Designator *d{GetDesignatorFromObj(object)}) {
43384330
if (auto &&expr{ea.Analyze(*d)}) {
43394331
if (hasBasePointer(*expr)) {
@@ -4486,7 +4478,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
44864478
}
44874479
}
44884480
if (taskDep) {
4489-
auto &objList{std::get<parser::OmpObjectList>(taskDep->t)};
4481+
auto &objList{*GetOmpObjectList(*taskDep)};
44904482
if (dir == llvm::omp::OMPD_depobj) {
44914483
// [5.0:255:13], [5.1:288:6], [5.2:322:26]
44924484
// A depend clause on a depobj construct must only specify one locator.
@@ -4632,7 +4624,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Copyprivate &x) {
46324624
void OmpStructureChecker::Enter(const parser::OmpClause::Lastprivate &x) {
46334625
CheckAllowedClause(llvm::omp::Clause::OMPC_lastprivate);
46344626

4635-
const auto &objectList{std::get<parser::OmpObjectList>(x.v.t)};
4627+
const auto &objectList{*GetOmpObjectList(x)};
46364628
CheckVarIsNotPartOfAnotherVar(
46374629
GetContext().clauseSource, objectList, "LASTPRIVATE");
46384630
CheckCrayPointee(objectList, "LASTPRIVATE");
@@ -4874,9 +4866,8 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Enter &x) {
48744866
x.v, llvm::omp::OMPC_enter, GetContext().clauseSource, context_)) {
48754867
return;
48764868
}
4877-
const parser::OmpObjectList &objList{std::get<parser::OmpObjectList>(x.v.t)};
48784869
SymbolSourceMap symbols;
4879-
GetSymbolsInObjectList(objList, symbols);
4870+
GetSymbolsInObjectList(*GetOmpObjectList(x), symbols);
48804871
for (const auto &[symbol, source] : symbols) {
48814872
if (!IsExtendedListItem(*symbol)) {
48824873
context_.SayWithDecl(*symbol, source,
@@ -4899,7 +4890,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::From &x) {
48994890
CheckIteratorModifier(*iter);
49004891
}
49014892

4902-
const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
4893+
const auto &objList{*GetOmpObjectList(x)};
49034894
SymbolSourceMap symbols;
49044895
GetSymbolsInObjectList(objList, symbols);
49054896
CheckVariableListItem(symbols);
@@ -4939,7 +4930,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::To &x) {
49394930
CheckIteratorModifier(*iter);
49404931
}
49414932

4942-
const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
4933+
const auto &objList{*GetOmpObjectList(x)};
49434934
SymbolSourceMap symbols;
49444935
GetSymbolsInObjectList(objList, symbols);
49454936
CheckVariableListItem(symbols);

0 commit comments

Comments
 (0)