Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions flang/include/flang/Parser/openmp-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define FORTRAN_PARSER_OPENMP_UTILS_H

#include "flang/Common/indirection.h"
#include "flang/Common/template.h"
#include "flang/Parser/parse-tree.h"
#include "llvm/Frontend/OpenMP/OMP.h"

Expand Down Expand Up @@ -127,7 +128,62 @@ template <typename T> struct IsStatement<Statement<T>> {
std::optional<Label> GetStatementLabel(const ExecutionPartConstruct &x);
std::optional<Label> GetFinalLabel(const OpenMPConstruct &x);

namespace detail {
// Clauses with OmpObjectList as its data member
using MemberObjectListClauses =
std::tuple<OmpClause::Copyin, OmpClause::Copyprivate, OmpClause::Exclusive,
OmpClause::Firstprivate, OmpClause::HasDeviceAddr, OmpClause::Inclusive,
OmpClause::IsDevicePtr, OmpClause::Link, OmpClause::Private,
OmpClause::Shared, OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;

// Clauses with OmpObjectList in the tuple
using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;

template <typename...> struct WrappedInType;

template <typename T> struct WrappedInType<T> {
static constexpr bool value{false};
};

template <typename T, typename U, typename... Us>
struct WrappedInType<T, U, Us...> {
static constexpr bool value{//
std::is_same_v<T, decltype(U::v)> || WrappedInType<T, Us...>::value};
};

template <typename...> struct WrappedInTuple {
static constexpr bool value{false};
};
template <typename T, typename... Us>
struct WrappedInTuple<T, std::tuple<Us...>> {
static constexpr bool value{WrappedInType<T, Us...>::value};
};
template <typename T, typename U>
constexpr bool WrappedInTupleV{WrappedInTuple<T, U>::value};
} // namespace detail

template <typename T> const OmpObjectList *GetOmpObjectList(const T &clause) {
using namespace detail;

if constexpr (common::HasMember<T, MemberObjectListClauses>) {
return &clause.v;
} else if constexpr (common::HasMember<T, TupleObjectListClauses>) {
return &std::get<OmpObjectList>(clause.v.t);
} else if constexpr (WrappedInTupleV<T, TupleObjectListClauses>) {
return &std::get<OmpObjectList>(clause.t);
} else {
static_assert(std::is_class_v<T>, "Unexpected argument type");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why didn't this fire for the issue Kareem pointed out?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x was of type OpenMPAllocatorsConstruct, which is a class/struct. This assertion is intended to flag pointers that are easy to pass by accident.

I'll try to add a more detailed check.

return nullptr;
}
}

const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause);
const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x);

template <typename T>
const T *GetFirstArgument(const OmpDirectiveSpecification &spec) {
Expand Down
45 changes: 11 additions & 34 deletions flang/lib/Parser/openmp-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,43 +117,20 @@ std::optional<Label> GetFinalLabel(const OpenMPConstruct &x) {
}

const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
// Clauses with OmpObjectList as its data member
using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
OmpClause::Copyprivate, OmpClause::Exclusive, OmpClause::Firstprivate,
OmpClause::HasDeviceAddr, OmpClause::Inclusive, OmpClause::IsDevicePtr,
OmpClause::Link, OmpClause::Private, OmpClause::Shared,
OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;

// Clauses with OmpObjectList in the tuple
using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;

// TODO:: Generate the tuples using TableGen.
return common::visit([](auto &&s) { return GetOmpObjectList(s); }, clause.u);
}

const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause) {
return common::visit(
common::visitors{
[&](const OmpClause::Depend &x) -> const OmpObjectList * {
if (auto *taskDep{std::get_if<OmpDependClause::TaskDep>(&x.v.u)}) {
return &std::get<OmpObjectList>(taskDep->t);
} else {
return nullptr;
}
},
[&](const auto &x) -> const OmpObjectList * {
using Ty = std::decay_t<decltype(x)>;
if constexpr (common::HasMember<Ty, MemberObjectListClauses>) {
return &x.v;
} else if constexpr (common::HasMember<Ty,
TupleObjectListClauses>) {
return &std::get<OmpObjectList>(x.v.t);
} else {
return nullptr;
}
},
[](const OmpDoacross &) { return nullptr; },
[](const OmpDependClause::TaskDep &x) { return GetOmpObjectList(x); },
},
clause.u);
clause.v.u);
}

const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x) {
return &std::get<OmpObjectList>(x.t);
}

const BlockConstruct *GetFortranBlockConstruct(
Expand Down
9 changes: 3 additions & 6 deletions flang/lib/Semantics/check-omp-loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,8 @@ void OmpStructureChecker::CheckDistLinear(

// Collect symbols of all the variables from linear clauses
for (auto &clause : clauses.v) {
if (auto *linearClause{std::get_if<parser::OmpClause::Linear>(&clause.u)}) {
auto &objects{std::get<parser::OmpObjectList>(linearClause->v.t)};
GetSymbolsInObjectList(objects, indexVars);
if (std::get_if<parser::OmpClause::Linear>(&clause.u)) {
GetSymbolsInObjectList(*parser::omp::GetOmpObjectList(clause), indexVars);
}
}

Expand Down Expand Up @@ -604,8 +603,6 @@ void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
auto *maybeModifier{OmpGetUniqueModifier<ReductionModifier>(modifiers)};
if (maybeModifier &&
maybeModifier->v == ReductionModifier::Value::Inscan) {
const auto &objectList{
std::get<parser::OmpObjectList>(reductionClause->v.t)};
auto checkReductionSymbolInScan = [&](const parser::Name *name) {
if (auto &symbol = name->symbol) {
if (!symbol->test(Symbol::Flag::OmpInclusiveScan) &&
Expand All @@ -618,7 +615,7 @@ void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
}
}
};
for (const auto &ompObj : objectList.v) {
for (const auto &ompObj : parser::omp::GetOmpObjectList(clause)->v) {
common::visit(
common::visitors{
[&](const parser::Designator &designator) {
Expand Down
75 changes: 33 additions & 42 deletions flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,11 +624,9 @@ void OmpStructureChecker::CheckMultListItems() {

// Linear clause
for (auto [_, clause] : FindClauses(llvm::omp::Clause::OMPC_linear)) {
auto &linearClause{std::get<parser::OmpClause::Linear>(clause->u)};
std::list<parser::Name> nameList;
SymbolSourceMap symbols;
GetSymbolsInObjectList(
std::get<parser::OmpObjectList>(linearClause.v.t), symbols);
GetSymbolsInObjectList(*GetOmpObjectList(*clause), symbols);
llvm::transform(symbols, std::back_inserter(nameList), [&](auto &&pair) {
return parser::Name{pair.second, const_cast<Symbol *>(pair.first)};
});
Expand Down Expand Up @@ -2101,29 +2099,29 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclareTargetConstruct &x) {
}
}

bool toClauseFound{false}, deviceTypeClauseFound{false},
enterClauseFound{false};
bool toClauseFound{false};
bool deviceTypeClauseFound{false};
bool enterClauseFound{false};
for (const parser::OmpClause &clause : x.v.Clauses().v) {
common::visit(
common::visitors{
[&](const parser::OmpClause::To &toClause) {
toClauseFound = true;
auto &objList{std::get<parser::OmpObjectList>(toClause.v.t)};
CheckSymbolNames(dirName.source, objList);
CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
CheckThreadprivateOrDeclareTargetVar(objList);
},
[&](const parser::OmpClause::Link &linkClause) {
CheckSymbolNames(dirName.source, linkClause.v);
CheckVarIsNotPartOfAnotherVar(dirName.source, linkClause.v);
CheckThreadprivateOrDeclareTargetVar(linkClause.v);
},
[&](const parser::OmpClause::Enter &enterClause) {
enterClauseFound = true;
auto &objList{std::get<parser::OmpObjectList>(enterClause.v.t)};
CheckSymbolNames(dirName.source, objList);
CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
CheckThreadprivateOrDeclareTargetVar(objList);
[&](const auto &c) {
using TypeC = llvm::remove_cvref_t<decltype(c)>;
if constexpr ( //
std::is_same_v<TypeC, parser::OmpClause::Enter> ||
std::is_same_v<TypeC, parser::OmpClause::Link> ||
std::is_same_v<TypeC, parser::OmpClause::To>) {
auto &objList{*GetOmpObjectList(c)};
CheckSymbolNames(dirName.source, objList);
CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
CheckThreadprivateOrDeclareTargetVar(objList);
}
if constexpr (std::is_same_v<TypeC, parser::OmpClause::Enter>) {
enterClauseFound = true;
}
if constexpr (std::is_same_v<TypeC, parser::OmpClause::To>) {
toClauseFound = true;
}
},
[&](const parser::OmpClause::DeviceType &deviceTypeClause) {
deviceTypeClauseFound = true;
Expand All @@ -2134,7 +2132,6 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclareTargetConstruct &x) {
deviceConstructFound_ = true;
}
},
[&](const auto &) {},
},
clause.u);

Expand Down Expand Up @@ -2424,12 +2421,8 @@ void OmpStructureChecker::CheckTargetUpdate() {
}
if (toWrapper && fromWrapper) {
SymbolSourceMap toSymbols, fromSymbols;
auto &fromClause{std::get<parser::OmpClause::From>(fromWrapper->u).v};
auto &toClause{std::get<parser::OmpClause::To>(toWrapper->u).v};
GetSymbolsInObjectList(
std::get<parser::OmpObjectList>(fromClause.t), fromSymbols);
GetSymbolsInObjectList(
std::get<parser::OmpObjectList>(toClause.t), toSymbols);
GetSymbolsInObjectList(*GetOmpObjectList(*fromWrapper), fromSymbols);
GetSymbolsInObjectList(*GetOmpObjectList(*toWrapper), toSymbols);

for (auto &[symbol, source] : toSymbols) {
auto fromSymbol{fromSymbols.find(symbol)};
Expand Down Expand Up @@ -3269,7 +3262,7 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) {
const auto &irClause{
std::get<parser::OmpClause::InReduction>(dataEnvClause->u)};
checkVarAppearsInDataEnvClause(
std::get<parser::OmpObjectList>(irClause.v.t), "IN_REDUCTION");
*GetOmpObjectList(irClause), "IN_REDUCTION");
}
}
}
Expand Down Expand Up @@ -3436,7 +3429,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Destroy &x) {

void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_reduction);
auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
auto &objects{*GetOmpObjectList(x)};

if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_reduction,
GetContext().clauseSource, context_)) {
Expand Down Expand Up @@ -3476,7 +3469,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {

void OmpStructureChecker::Enter(const parser::OmpClause::InReduction &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_in_reduction);
auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
auto &objects{*GetOmpObjectList(x)};

if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_in_reduction,
GetContext().clauseSource, context_)) {
Expand All @@ -3494,7 +3487,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::InReduction &x) {

void OmpStructureChecker::Enter(const parser::OmpClause::TaskReduction &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_task_reduction);
auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
auto &objects{*GetOmpObjectList(x)};

if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_task_reduction,
GetContext().clauseSource, context_)) {
Expand Down Expand Up @@ -4347,8 +4340,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) {
}};

evaluate::ExpressionAnalyzer ea{context_};
const auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
for (auto &object : objects.v) {
for (auto &object : GetOmpObjectList(x)->v) {
if (const parser::Designator *d{GetDesignatorFromObj(object)}) {
if (auto &&expr{ea.Analyze(*d)}) {
if (hasBasePointer(*expr)) {
Expand Down Expand Up @@ -4501,7 +4493,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
}
}
if (taskDep) {
auto &objList{std::get<parser::OmpObjectList>(taskDep->t)};
auto &objList{*GetOmpObjectList(*taskDep)};
if (dir == llvm::omp::OMPD_depobj) {
// [5.0:255:13], [5.1:288:6], [5.2:322:26]
// A depend clause on a depobj construct must only specify one locator.
Expand Down Expand Up @@ -4647,7 +4639,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Copyprivate &x) {
void OmpStructureChecker::Enter(const parser::OmpClause::Lastprivate &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_lastprivate);

const auto &objectList{std::get<parser::OmpObjectList>(x.v.t)};
const auto &objectList{*GetOmpObjectList(x)};
CheckVarIsNotPartOfAnotherVar(
GetContext().clauseSource, objectList, "LASTPRIVATE");
CheckCrayPointee(objectList, "LASTPRIVATE");
Expand Down Expand Up @@ -4889,9 +4881,8 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Enter &x) {
x.v, llvm::omp::OMPC_enter, GetContext().clauseSource, context_)) {
return;
}
const parser::OmpObjectList &objList{std::get<parser::OmpObjectList>(x.v.t)};
SymbolSourceMap symbols;
GetSymbolsInObjectList(objList, symbols);
GetSymbolsInObjectList(*GetOmpObjectList(x), symbols);
for (const auto &[symbol, source] : symbols) {
if (!IsExtendedListItem(*symbol)) {
context_.SayWithDecl(*symbol, source,
Expand All @@ -4914,7 +4905,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::From &x) {
CheckIteratorModifier(*iter);
}

const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
const auto &objList{*GetOmpObjectList(x)};
SymbolSourceMap symbols;
GetSymbolsInObjectList(objList, symbols);
CheckVariableListItem(symbols);
Expand Down Expand Up @@ -4954,7 +4945,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::To &x) {
CheckIteratorModifier(*iter);
}

const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
const auto &objList{*GetOmpObjectList(x)};
SymbolSourceMap symbols;
GetSymbolsInObjectList(objList, symbols);
CheckVariableListItem(symbols);
Expand Down
Loading