Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 169 additions & 87 deletions flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,30 @@ std::string TryVersion(unsigned version) {
return "try -fopenmp-version=" + std::to_string(version);
}

static const parser::Designator *GetDesignatorFromObj(
const parser::OmpObject &object) {
return std::get_if<parser::Designator>(&object.u);
}

static const parser::DataRef *GetDataRefFromObj(
const parser::OmpObject &object) {
if (auto *desg{GetDesignatorFromObj(object)}) {
return std::get_if<parser::DataRef>(&desg->u);
}
return nullptr;
}

static const parser::ArrayElement *GetArrayElementFromObj(
const parser::OmpObject &object) {
if (auto *dataRef{GetDataRefFromObj(object)}) {
using ElementIndirection = common::Indirection<parser::ArrayElement>;
if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
return &ind->value();
}
}
return nullptr;
}

// 'OmpWorkshareBlockChecker' is used to check the validity of the assignment
// statements and the expressions enclosed in an OpenMP Workshare construct
class OmpWorkshareBlockChecker {
Expand Down Expand Up @@ -222,6 +246,10 @@ bool OmpStructureChecker::CheckAllowedClause(llvmOmpClause clause) {
return CheckAllowed(clause);
}

bool OmpStructureChecker::IsCommonBlock(const Symbol &sym) {
return sym.detailsIf<CommonBlockDetails>() != nullptr;
}

bool OmpStructureChecker::IsVariableListItem(const Symbol &sym) {
return evaluate::IsVariable(sym) || sym.attrs().test(Attr::POINTER);
}
Expand Down Expand Up @@ -2883,6 +2911,8 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Destroy &x) {

void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_reduction);
auto &objects{std::get<parser::OmpObjectList>(x.v.t)};

if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_reduction,
GetContext().clauseSource, context_)) {
if (CheckReductionOperators(x)) {
Expand All @@ -2895,6 +2925,13 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
CheckReductionModifier(*maybeModifier);
}
}
CheckReductionObjects(objects, llvm::omp::Clause::OMPC_reduction);

// If this is a worksharing construct then ensure the reduction variable
// is not private in the parallel region that it binds to.
if (llvm::omp::nestedReduceWorkshareAllowedSet.test(GetContext().directive)) {
CheckSharedBindingInOuterContext(objects);
}
}

bool OmpStructureChecker::CheckReductionOperators(
Expand Down Expand Up @@ -2963,6 +3000,69 @@ bool OmpStructureChecker::CheckIntrinsicOperator(
return false;
}

/// Check restrictions on objects that are common to all reduction clauses.
void OmpStructureChecker::CheckReductionObjects(
const parser::OmpObjectList &objects, llvm::omp::Clause clauseId) {
unsigned version{context_.langOptions().OpenMPVersion};
SymbolSourceMap symbols;
GetSymbolsInObjectList(objects, symbols);

// Array sections must be a contiguous storage, have non-zero length.
for (const parser::OmpObject &object : objects.v) {
CheckIfContiguous(object);
}
CheckReductionArraySection(objects);
// An object must be definable.
CheckDefinableObjects(symbols, clauseId);
// Procedure pointers are not allowed.
CheckProcedurePointer(symbols, clauseId);
// Pointers must not have INTENT(IN).
CheckIntentInPointer(symbols, clauseId);

// Disallow common blocks.
// Iterate on objects because `GetSymbolsInObjectList` expands common block
// names into the lists of their members.
for (const parser::OmpObject &object : objects.v) {
auto *symbol{GetObjectSymbol(object)};
assert(symbol && "Expecting a symbol for object");
if (IsCommonBlock(*symbol)) {
auto source{GetObjectSource(object)};
context_.Say(source ? *source : GetContext().clauseSource,
"Common block names are not allowed in %s clause"_err_en_US,
parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
}
}

if (version >= 50) {
// Object cannot be a part of another object (except array elements)
CheckStructureComponent(objects, clauseId);
// If object is an array section or element, the base expression must be
// a language identifier.
for (const parser::OmpObject &object : objects.v) {
if (auto *elem{GetArrayElementFromObj(object)}) {
const parser::DataRef &base = elem->base;
if (!std::holds_alternative<parser::Name>(base.u)) {
auto source{GetObjectSource(object)};
context_.Say(source ? *source : GetContext().clauseSource,
"The base expression of an array element in %s clause must be an identifier"_err_en_US,
parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
}
}
}
// Type parameter inquiries are not allowed.
for (const parser::OmpObject &object : objects.v) {
if (auto *dataRef{GetDataRefFromObj(object)}) {
if (IsDataRefTypeParamInquiry(dataRef)) {
auto source{GetObjectSource(object)};
context_.Say(source ? *source : GetContext().clauseSource,
"Type parameter inquiry is not permitted in %s clause"_err_en_US,
parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
}
}
}
}
}

static bool IsReductionAllowedForType(
const parser::OmpClause::Reduction &x, const DeclTypeSpec &type) {
auto &modifiers{OmpGetModifiers(x.v)};
Expand Down Expand Up @@ -3052,26 +3152,18 @@ static bool IsReductionAllowedForType(
void OmpStructureChecker::CheckReductionTypeList(
const parser::OmpClause::Reduction &x) {
const auto &ompObjectList{std::get<parser::OmpObjectList>(x.v.t)};
CheckIntentInPointerAndDefinable(
ompObjectList, llvm::omp::Clause::OMPC_reduction);
CheckReductionArraySection(ompObjectList);
// If this is a worksharing construct then ensure the reduction variable
// is not private in the parallel region that it binds to.
if (llvm::omp::nestedReduceWorkshareAllowedSet.test(GetContext().directive)) {
CheckSharedBindingInOuterContext(ompObjectList);
}

SymbolSourceMap symbols;
GetSymbolsInObjectList(ompObjectList, symbols);

for (auto &[symbol, source] : symbols) {
if (IsProcedurePointer(*symbol)) {
context_.Say(source,
"A procedure pointer '%s' must not appear in a REDUCTION clause."_err_en_US,
symbol->name());
} else if (!IsReductionAllowedForType(x, DEREF(symbol->GetType()))) {
context_.Say(source,
"The type of '%s' is incompatible with the reduction operator."_err_en_US,
symbol->name());
if (auto *type{symbol->GetType()}) {
if (!IsReductionAllowedForType(x, *type)) {
context_.Say(source,
"The type of '%s' is incompatible with the reduction operator."_err_en_US,
symbol->name());
}
} else {
assert(IsProcedurePointer(*symbol) && "Unexpected symbol properties");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is checked elsewhere (in the part that is independent of the reduction operator/type). The assertion is here to catch cases where type is not available that are not procedure pointers.

}
}
}
Expand Down Expand Up @@ -3127,43 +3219,14 @@ void OmpStructureChecker::CheckReductionModifier(
}
}

void OmpStructureChecker::CheckIntentInPointerAndDefinable(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The checks from this function already existed in separate functions. They are used instead.

const parser::OmpObjectList &objectList, const llvm::omp::Clause clause) {
for (const auto &ompObject : objectList.v) {
if (const auto *name{parser::Unwrap<parser::Name>(ompObject)}) {
if (const auto *symbol{name->symbol}) {
if (IsPointer(symbol->GetUltimate()) &&
IsIntentIn(symbol->GetUltimate())) {
context_.Say(GetContext().clauseSource,
"Pointer '%s' with the INTENT(IN) attribute may not appear "
"in a %s clause"_err_en_US,
symbol->name(),
parser::ToUpperCaseLetters(getClauseName(clause).str()));
} else if (auto msg{WhyNotDefinable(name->source,
context_.FindScope(name->source), DefinabilityFlags{},
*symbol)}) {
context_
.Say(GetContext().clauseSource,
"Variable '%s' on the %s clause is not definable"_err_en_US,
symbol->name(),
parser::ToUpperCaseLetters(getClauseName(clause).str()))
.Attach(std::move(msg->set_severity(parser::Severity::Because)));
}
}
}
}
}

void OmpStructureChecker::CheckReductionArraySection(
const parser::OmpObjectList &ompObjectList) {
for (const auto &ompObject : ompObjectList.v) {
if (const auto *dataRef{parser::Unwrap<parser::DataRef>(ompObject)}) {
if (const auto *arrayElement{
parser::Unwrap<parser::ArrayElement>(ompObject)}) {
if (arrayElement) {
CheckArraySection(*arrayElement, GetLastName(*dataRef),
llvm::omp::Clause::OMPC_reduction);
}
CheckArraySection(*arrayElement, GetLastName(*dataRef),
llvm::omp::Clause::OMPC_reduction);
}
}
}
Expand Down Expand Up @@ -3232,9 +3295,11 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Shared &x) {
CheckIsVarPartOfAnotherVar(GetContext().clauseSource, x.v, "SHARED");
}
void OmpStructureChecker::Enter(const parser::OmpClause::Private &x) {
SymbolSourceMap symbols;
GetSymbolsInObjectList(x.v, symbols);
CheckAllowedClause(llvm::omp::Clause::OMPC_private);
CheckIsVarPartOfAnotherVar(GetContext().clauseSource, x.v, "PRIVATE");
CheckIntentInPointer(x.v, llvm::omp::Clause::OMPC_private);
CheckIntentInPointer(symbols, llvm::omp::Clause::OMPC_private);
}

void OmpStructureChecker::Enter(const parser::OmpClause::Nowait &x) {
Expand Down Expand Up @@ -3891,11 +3956,11 @@ void OmpStructureChecker::CheckCopyingPolymorphicAllocatable(

void OmpStructureChecker::Enter(const parser::OmpClause::Copyprivate &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_copyprivate);
CheckIntentInPointer(x.v, llvm::omp::Clause::OMPC_copyprivate);
SymbolSourceMap currSymbols;
GetSymbolsInObjectList(x.v, currSymbols);
SymbolSourceMap symbols;
GetSymbolsInObjectList(x.v, symbols);
CheckIntentInPointer(symbols, llvm::omp::Clause::OMPC_copyprivate);
CheckCopyingPolymorphicAllocatable(
currSymbols, llvm::omp::Clause::OMPC_copyprivate);
symbols, llvm::omp::Clause::OMPC_copyprivate);
if (GetContext().directive == llvm::omp::Directive::OMPD_single) {
context_.Say(GetContext().clauseSource,
"%s clause is not allowed on the OMP %s directive,"
Expand Down Expand Up @@ -3945,29 +4010,26 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Copyin &x) {
currSymbols, llvm::omp::Clause::OMPC_copyin);
}

void OmpStructureChecker::CheckStructureElement(
const parser::OmpObjectList &ompObjectList,
const llvm::omp::Clause clause) {
for (const auto &ompObject : ompObjectList.v) {
void OmpStructureChecker::CheckStructureComponent(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is almost the same as the code below, but with the lambda extracted out (for better formatting). The message was changed slightly, and the source location is now the component name.

const parser::OmpObjectList &objects, llvm::omp::Clause clauseId) {
auto CheckComponent{[&](const parser::Designator &designator) {
if (auto *desg{std::get_if<parser::DataRef>(&designator.u)}) {
if (auto *comp{parser::Unwrap<parser::StructureComponent>(*desg)}) {
context_.Say(comp->component.source,
"A variable that is part of another variable cannot appear on the %s clause"_err_en_US,
parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
}
}
}};

for (const auto &object : objects.v) {
common::visit(
common::visitors{
[&](const parser::Designator &designator) {
if (std::get_if<parser::DataRef>(&designator.u)) {
if (parser::Unwrap<parser::StructureComponent>(ompObject)) {
context_.Say(GetContext().clauseSource,
"A variable that is part of another variable "
"(structure element) cannot appear on the %s "
"%s clause"_err_en_US,
ContextDirectiveAsFortran(),
parser::ToUpperCaseLetters(getClauseName(clause).str()));
}
}
},
CheckComponent,
[&](const parser::Name &name) {},
},
ompObject.u);
object.u);
}
return;
}

void OmpStructureChecker::Enter(const parser::OmpClause::Update &x) {
Expand Down Expand Up @@ -4009,7 +4071,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Update &x) {
}

void OmpStructureChecker::Enter(const parser::OmpClause::UseDevicePtr &x) {
CheckStructureElement(x.v, llvm::omp::Clause::OMPC_use_device_ptr);
CheckStructureComponent(x.v, llvm::omp::Clause::OMPC_use_device_ptr);
CheckAllowedClause(llvm::omp::Clause::OMPC_use_device_ptr);
SymbolSourceMap currSymbols;
GetSymbolsInObjectList(x.v, currSymbols);
Expand Down Expand Up @@ -4038,7 +4100,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::UseDevicePtr &x) {
}

void OmpStructureChecker::Enter(const parser::OmpClause::UseDeviceAddr &x) {
CheckStructureElement(x.v, llvm::omp::Clause::OMPC_use_device_addr);
CheckStructureComponent(x.v, llvm::omp::Clause::OMPC_use_device_addr);
CheckAllowedClause(llvm::omp::Clause::OMPC_use_device_addr);
SymbolSourceMap currSymbols;
GetSymbolsInObjectList(x.v, currSymbols);
Expand Down Expand Up @@ -4214,6 +4276,26 @@ llvm::StringRef OmpStructureChecker::getDirectiveName(
return llvm::omp::getOpenMPDirectiveName(directive);
}

const Symbol *OmpStructureChecker::GetObjectSymbol(
const parser::OmpObject &object) {
if (auto *name{std::get_if<parser::Name>(&object.u)}) {
return &name->symbol->GetUltimate();
} else if (auto *desg{std::get_if<parser::Designator>(&object.u)}) {
return &GetLastName(*desg).symbol->GetUltimate();
}
return nullptr;
}

std::optional<parser::CharBlock> OmpStructureChecker::GetObjectSource(
const parser::OmpObject &object) {
if (auto *name{std::get_if<parser::Name>(&object.u)}) {
return name->source;
} else if (auto *desg{std::get_if<parser::Designator>(&object.u)}) {
return GetLastName(*desg).source;
}
return std::nullopt;
}

void OmpStructureChecker::CheckDependList(const parser::DataRef &d) {
common::visit(
common::visitors{
Expand Down Expand Up @@ -4267,15 +4349,6 @@ void OmpStructureChecker::CheckArraySection(
"DEPEND "
"clause"_err_en_US);
}
const auto stride{GetIntValue(strideExpr)};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is checked directly via IsContiguous, no longer necessary here.

if ((stride && stride != 1)) {
context_.Say(GetContext().clauseSource,
"A list item that appears in a REDUCTION clause"
" should have a contiguous storage array "
"section."_err_en_US,
ContextDirectiveAsFortran());
break;
}
}
}
}
Expand All @@ -4286,14 +4359,23 @@ void OmpStructureChecker::CheckArraySection(
}

void OmpStructureChecker::CheckIntentInPointer(
const parser::OmpObjectList &objectList, const llvm::omp::Clause clause) {
SymbolSourceMap symbols;
GetSymbolsInObjectList(objectList, symbols);
SymbolSourceMap &symbols, llvm::omp::Clause clauseId) {
for (auto &[symbol, source] : symbols) {
if (IsPointer(*symbol) && IsIntentIn(*symbol)) {
context_.Say(source,
"Pointer '%s' with the INTENT(IN) attribute may not appear "
"in a %s clause"_err_en_US,
"Pointer '%s' with the INTENT(IN) attribute may not appear in a %s clause"_err_en_US,
symbol->name(),
parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
}
}
}

void OmpStructureChecker::CheckProcedurePointer(
SymbolSourceMap &symbols, llvm::omp::Clause clause) {
for (const auto &[symbol, source] : symbols) {
if (IsProcedurePointer(*symbol)) {
context_.Say(source,
"Procedure pointer '%s' may not appear in a %s clause"_err_en_US,
symbol->name(),
parser::ToUpperCaseLetters(getClauseName(clause).str()));
}
Expand Down
Loading
Loading