Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/include/flang/Semantics/openmp-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ template <typename T, typename U = std::remove_const_t<T>> U AsRvalue(T &t) {
template <typename T> T &&AsRvalue(T &&t) { return std::move(t); }

const Scope &GetScopingUnit(const Scope &scope);
const Scope &GetProgramUnit(const Scope &scope);

// There is no consistent way to get the source of an ActionStmt, but there
// is "source" in Statement<T>. This structure keeps the ActionStmt with the
Expand Down
25 changes: 14 additions & 11 deletions flang/include/flang/Semantics/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "flang/Semantics/module-dependences.h"
#include "flang/Support/Fortran.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/Frontend/OpenMP/OMP.h"

#include <array>
#include <functional>
Expand Down Expand Up @@ -50,32 +51,34 @@ using MutableSymbolVector = std::vector<MutableSymbolRef>;

// Mixin for details with OpenMP declarative constructs.
class WithOmpDeclarative {
using OmpAtomicOrderType = common::OmpMemoryOrderType;

public:
ENUM_CLASS(RequiresFlag, ReverseOffload, UnifiedAddress, UnifiedSharedMemory,
DynamicAllocators);
using RequiresFlags = common::EnumSet<RequiresFlag, RequiresFlag_enumSize>;
// The set of requirements for any program unit include requirements
// from any module used in the program unit.
using RequiresClauses =
common::EnumSet<llvm::omp::Clause, llvm::omp::Clause_enumSize>;

bool has_ompRequires() const { return ompRequires_.has_value(); }
const RequiresFlags *ompRequires() const {
const RequiresClauses *ompRequires() const {
return ompRequires_ ? &*ompRequires_ : nullptr;
}
void set_ompRequires(RequiresFlags flags) { ompRequires_ = flags; }
void set_ompRequires(RequiresClauses clauses) { ompRequires_ = clauses; }

bool has_ompAtomicDefaultMemOrder() const {
return ompAtomicDefaultMemOrder_.has_value();
}
const OmpAtomicOrderType *ompAtomicDefaultMemOrder() const {
const common::OmpMemoryOrderType *ompAtomicDefaultMemOrder() const {
return ompAtomicDefaultMemOrder_ ? &*ompAtomicDefaultMemOrder_ : nullptr;
}
void set_ompAtomicDefaultMemOrder(OmpAtomicOrderType flags) {
void set_ompAtomicDefaultMemOrder(common::OmpMemoryOrderType flags) {
ompAtomicDefaultMemOrder_ = flags;
}

friend llvm::raw_ostream &operator<<(
llvm::raw_ostream &, const WithOmpDeclarative &);

private:
std::optional<RequiresFlags> ompRequires_;
std::optional<OmpAtomicOrderType> ompAtomicDefaultMemOrder_;
std::optional<RequiresClauses> ompRequires_;
std::optional<common::OmpMemoryOrderType> ompAtomicDefaultMemOrder_;
};

// A module or submodule.
Expand Down
15 changes: 7 additions & 8 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4208,18 +4208,17 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
const semantics::Symbol *symbol) {
using MlirRequires = mlir::omp::ClauseRequires;
using SemaRequires = semantics::WithOmpDeclarative::RequiresFlag;

if (auto offloadMod =
llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) {
semantics::WithOmpDeclarative::RequiresFlags semaFlags;
semantics::WithOmpDeclarative::RequiresClauses reqs;
if (symbol) {
common::visit(
[&](const auto &details) {
if constexpr (std::is_base_of_v<semantics::WithOmpDeclarative,
std::decay_t<decltype(details)>>) {
if (details.has_ompRequires())
semaFlags = *details.ompRequires();
reqs = *details.ompRequires();
}
},
symbol->details());
Expand All @@ -4228,14 +4227,14 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
// Use pre-populated omp.requires module attribute if it was set, so that
// the "-fopenmp-force-usm" compiler option is honored.
MlirRequires mlirFlags = offloadMod.getRequires();
if (semaFlags.test(SemaRequires::ReverseOffload))
if (reqs.test(llvm::omp::Clause::OMPC_dynamic_allocators))
mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;
if (reqs.test(llvm::omp::Clause::OMPC_reverse_offload))
mlirFlags = mlirFlags | MlirRequires::reverse_offload;
if (semaFlags.test(SemaRequires::UnifiedAddress))
if (reqs.test(llvm::omp::Clause::OMPC_unified_address))
mlirFlags = mlirFlags | MlirRequires::unified_address;
if (semaFlags.test(SemaRequires::UnifiedSharedMemory))
if (reqs.test(llvm::omp::Clause::OMPC_unified_shared_memory))
mlirFlags = mlirFlags | MlirRequires::unified_shared_memory;
if (semaFlags.test(SemaRequires::DynamicAllocators))
mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;

offloadMod.setRequires(mlirFlags);
}
Expand Down
37 changes: 37 additions & 0 deletions flang/lib/Semantics/mod-file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <fstream>
#include <set>
#include <string_view>
#include <type_traits>
#include <variant>
#include <vector>

Expand Down Expand Up @@ -359,6 +361,40 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) {
}
}

static void PutOpenMPRequirements(llvm::raw_ostream &os, const Symbol &symbol) {
using RequiresClauses = WithOmpDeclarative::RequiresClauses;
using OmpMemoryOrderType = common::OmpMemoryOrderType;

const auto [reqs, order]{common::visit(
[&](auto &&details)
-> std::pair<const RequiresClauses *, const OmpMemoryOrderType *> {
if constexpr (std::is_convertible_v<decltype(details),
const WithOmpDeclarative &>) {
return {details.ompRequires(), details.ompAtomicDefaultMemOrder()};
} else {
return {nullptr, nullptr};
}
},
symbol.details())};

if (order) {
llvm::omp::Clause atmo{llvm::omp::Clause::OMPC_atomic_default_mem_order};
os << "!$omp requires "
<< parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(atmo))
<< '(' << parser::ToLowerCaseLetters(EnumToString(*order)) << ")\n";
}
if (reqs) {
os << "!$omp requires";
reqs->IterateOverMembers([&](llvm::omp::Clause f) {
if (f != llvm::omp::Clause::OMPC_atomic_default_mem_order) {
os << ' '
<< parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f));
}
});
os << "\n";
}
}

// Put out the visible symbols from scope.
void ModFileWriter::PutSymbols(
const Scope &scope, UnorderedSymbolSet *hermeticModules) {
Expand Down Expand Up @@ -396,6 +432,7 @@ void ModFileWriter::PutSymbols(
for (const Symbol &symbol : uses) {
PutUse(symbol);
}
PutOpenMPRequirements(decls_, DEREF(scope.symbol()));
for (const auto &set : scope.equivalenceSets()) {
if (!set.empty() &&
!set.front().symbol.test(Symbol::Flag::CompilerCreated)) {
Expand Down
23 changes: 22 additions & 1 deletion flang/lib/Semantics/openmp-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "flang/Semantics/openmp-utils.h"

#include "flang/Common/Fortran-consts.h"
#include "flang/Common/idioms.h"
#include "flang/Common/indirection.h"
#include "flang/Common/reference.h"
#include "flang/Common/visit.h"
Expand Down Expand Up @@ -59,6 +60,26 @@ const Scope &GetScopingUnit(const Scope &scope) {
return *iter;
}

const Scope &GetProgramUnit(const Scope &scope) {
const Scope *unit{nullptr};
for (const Scope *iter{&scope}; !iter->IsTopLevel(); iter = &iter->parent()) {
switch (iter->kind()) {
case Scope::Kind::BlockData:
case Scope::Kind::MainProgram:
case Scope::Kind::Module:
return *iter;
case Scope::Kind::Subprogram:
// Ignore subprograms that are nested.
unit = iter;
break;
default:
break;
}
}
assert(unit && "Scope not in a program unit");
return *unit;
}

SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x) {
if (x == nullptr) {
return SourcedActionStmt{};
Expand Down Expand Up @@ -202,7 +223,7 @@ std::optional<SomeExpr> GetEvaluateExpr(const parser::Expr &parserExpr) {
// ForwardOwningPointer typedExpr
// `- GenericExprWrapper ^.get()
// `- std::optional<Expr> ^->v
return typedExpr.get()->v;
return DEREF(typedExpr.get()).v;
}

std::optional<evaluate::DynamicType> GetDynamicType(
Expand Down
Loading