Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions flang/include/flang/Semantics/openmp-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,21 @@

#include <optional>
#include <string>
#include <type_traits>
#include <utility>

namespace Fortran::semantics {
class SemanticsContext;
class Symbol;

// Add this namespace to avoid potential conflicts
namespace omp {
template <typename T, typename U = std::remove_const_t<T>> U AsRvalue(T &t) {
return U(t);
}

template <typename T> T &&AsRvalue(T &&t) { return std::move(t); }

// There is no consistent way to get the source of an ActionStmt, but there
// is "source" in Statement<T>. This structure keeps the ActionStmt with the
// extracted source for further use.
Expand Down
271 changes: 0 additions & 271 deletions flang/lib/Lower/OpenMP/Atomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,179 +43,6 @@ namespace omp {
using namespace Fortran::lower::omp;
}

namespace {
// An example of a type that can be used to get the return value from
// the visitor:
// visitor(type_identity<Xyz>) -> result_type
using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4>;

struct GetProc
: public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *,
false> {
using Result = const evaluate::ProcedureDesignator *;
using Base = evaluate::Traverse<GetProc, Result, false>;
GetProc() : Base(*this) {}

using Base::operator();

static Result Default() { return nullptr; }

Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; }
static Result Combine(Result a, Result b) { return a != nullptr ? a : b; }
};

struct WithType {
WithType(const evaluate::DynamicType &t) : type(t) {
assert(type.category() != common::TypeCategory::Derived &&
"Type cannot be a derived type");
}

template <typename VisitorTy> //
auto visit(VisitorTy &&visitor) const
-> std::invoke_result_t<VisitorTy, SomeArgType> {
switch (type.category()) {
case common::TypeCategory::Integer:
switch (type.kind()) {
case 1:
return visitor(llvm::type_identity<evaluate::Type<Integer, 1>>{});
case 2:
return visitor(llvm::type_identity<evaluate::Type<Integer, 2>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Integer, 4>>{});
case 8:
return visitor(llvm::type_identity<evaluate::Type<Integer, 8>>{});
case 16:
return visitor(llvm::type_identity<evaluate::Type<Integer, 16>>{});
}
break;
case common::TypeCategory::Unsigned:
switch (type.kind()) {
case 1:
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 1>>{});
case 2:
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 2>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 4>>{});
case 8:
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 8>>{});
case 16:
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 16>>{});
}
break;
case common::TypeCategory::Real:
switch (type.kind()) {
case 2:
return visitor(llvm::type_identity<evaluate::Type<Real, 2>>{});
case 3:
return visitor(llvm::type_identity<evaluate::Type<Real, 3>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Real, 4>>{});
case 8:
return visitor(llvm::type_identity<evaluate::Type<Real, 8>>{});
case 10:
return visitor(llvm::type_identity<evaluate::Type<Real, 10>>{});
case 16:
return visitor(llvm::type_identity<evaluate::Type<Real, 16>>{});
}
break;
case common::TypeCategory::Complex:
switch (type.kind()) {
case 2:
return visitor(llvm::type_identity<evaluate::Type<Complex, 2>>{});
case 3:
return visitor(llvm::type_identity<evaluate::Type<Complex, 3>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Complex, 4>>{});
case 8:
return visitor(llvm::type_identity<evaluate::Type<Complex, 8>>{});
case 10:
return visitor(llvm::type_identity<evaluate::Type<Complex, 10>>{});
case 16:
return visitor(llvm::type_identity<evaluate::Type<Complex, 16>>{});
}
break;
case common::TypeCategory::Logical:
switch (type.kind()) {
case 1:
return visitor(llvm::type_identity<evaluate::Type<Logical, 1>>{});
case 2:
return visitor(llvm::type_identity<evaluate::Type<Logical, 2>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Logical, 4>>{});
case 8:
return visitor(llvm::type_identity<evaluate::Type<Logical, 8>>{});
}
break;
case common::TypeCategory::Character:
switch (type.kind()) {
case 1:
return visitor(llvm::type_identity<evaluate::Type<Character, 1>>{});
case 2:
return visitor(llvm::type_identity<evaluate::Type<Character, 2>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Character, 4>>{});
}
break;
case common::TypeCategory::Derived:
(void)Derived;
break;
}
llvm_unreachable("Unhandled type");
}

const evaluate::DynamicType &type;

private:
// Shorter names.
static constexpr auto Character = common::TypeCategory::Character;
static constexpr auto Complex = common::TypeCategory::Complex;
static constexpr auto Derived = common::TypeCategory::Derived;
static constexpr auto Integer = common::TypeCategory::Integer;
static constexpr auto Logical = common::TypeCategory::Logical;
static constexpr auto Real = common::TypeCategory::Real;
static constexpr auto Unsigned = common::TypeCategory::Unsigned;
};

template <typename T, typename U = std::remove_const_t<T>>
U AsRvalue(T &t) {
U copy{t};
return std::move(copy);
}

template <typename T>
T &&AsRvalue(T &&t) {
return std::move(t);
}

struct ArgumentReplacer
: public evaluate::Traverse<ArgumentReplacer, bool, false> {
using Base = evaluate::Traverse<ArgumentReplacer, bool, false>;
using Result = bool;

Result Default() const { return false; }

ArgumentReplacer(evaluate::ActualArguments &&newArgs)
: Base(*this), args_(std::move(newArgs)) {}

using Base::operator();

template <typename T>
Result operator()(const evaluate::FunctionRef<T> &x) {
assert(!done_);
auto &mut = const_cast<evaluate::FunctionRef<T> &>(x);
mut.arguments() = args_;
done_ = true;
return true;
}

Result Combine(Result &&a, Result &&b) { return a || b; }

private:
bool done_{false};
evaluate::ActualArguments &&args_;
};
} // namespace

[[maybe_unused]] static void
dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) {
auto whatStr = [](int k) {
Expand Down Expand Up @@ -412,85 +239,6 @@ makeMemOrderAttr(lower::AbstractConverter &converter,
return nullptr;
}

static bool replaceArgs(semantics::SomeExpr &expr,
evaluate::ActualArguments &&newArgs) {
return ArgumentReplacer(std::move(newArgs))(expr);
}

static semantics::SomeExpr makeCall(const evaluate::DynamicType &type,
const evaluate::ProcedureDesignator &proc,
const evaluate::ActualArguments &args) {
return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr {
using Type = typename llvm::remove_cvref_t<decltype(s)>::type;
return evaluate::AsGenericExpr(
evaluate::FunctionRef<Type>(AsRvalue(proc), AsRvalue(args)));
});
}

static const evaluate::ProcedureDesignator &
getProcedureDesignator(const semantics::SomeExpr &call) {
const evaluate::ProcedureDesignator *proc = GetProc{}(call);
assert(proc && "Call has no procedure designator");
return *proc;
}

static semantics::SomeExpr //
genReducedMinMax(const semantics::SomeExpr &orig,
const semantics::SomeExpr *atomArg,
const std::vector<semantics::SomeExpr> &args) {
// Take a list of arguments to a min/max operation, e.g. [a0, a1, ...]
// One of the a_i's, say a_t, must be atomArg.
// Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate
// call = min/max(a_t, tmp).
// Return "call".

// The min/max intrinsics have 2 mandatory arguments, the rest is optional.
// Make sure that the "tmp = min/max(...)" doesn't promote an optional
// argument to a non-optional position. This could happen if a_t is at
// position 0 or 1.
if (args.size() <= 2)
return orig;

evaluate::ActualArguments nonAtoms;

auto AsActual = [](const semantics::SomeExpr &x) {
semantics::SomeExpr copy = x;
return evaluate::ActualArgument(std::move(copy));
};
// Semantic checks guarantee that the "atom" shows exactly once in the
// argument list (with potential conversions around it).
// For the first two (non-optional) arguments, if "atom" is among them,
// replace it with another occurrence of the other non-optional argument.
if (atomArg == &args[0]) {
// (atom, x, y...) -> (x, x, y...)
nonAtoms.push_back(AsActual(args[1]));
nonAtoms.push_back(AsActual(args[1]));
} else if (atomArg == &args[1]) {
// (x, atom, y...) -> (x, x, y...)
nonAtoms.push_back(AsActual(args[0]));
nonAtoms.push_back(AsActual(args[0]));
} else {
// (x, y, z...) -> unchanged
nonAtoms.push_back(AsActual(args[0]));
nonAtoms.push_back(AsActual(args[1]));
}

// The rest of arguments are optional, so we can just skip "atom".
for (size_t i = 2, e = args.size(); i != e; ++i) {
if (atomArg != &args[i])
nonAtoms.push_back(AsActual(args[i]));
}

// The type of the intermediate min/max is the same as the type of its
// arguments, which may be different from the type of the original
// expression. The original expression may have additional coverts.
auto tmp =
makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms);
semantics::SomeExpr call = orig;
replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)});
return call;
}

static mlir::Operation * //
genAtomicRead(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
Expand Down Expand Up @@ -610,25 +358,6 @@ genAtomicUpdate(lower::AbstractConverter &converter,
auto [opcode, args] = evaluate::GetTopLevelOperationIgnoreResizing(input);
assert(!args.empty() && "Update operation without arguments");

// Pass args as an argument to avoid capturing a structured binding.
const semantics::SomeExpr *atomArg = [&](auto &args) {
for (const semantics::SomeExpr &e : args) {
if (evaluate::IsSameOrConvertOf(e, atom))
return &e;
}
llvm_unreachable("Atomic variable not in argument list");
}(args);

if (opcode == evaluate::operation::Operator::Min ||
opcode == evaluate::operation::Operator::Max) {
// Min and max operations are expanded inline, so reduce them to
// operations with exactly two (non-optional) arguments.
rhs = genReducedMinMax(rhs, atomArg, args);
input = *evaluate::GetConvertInput(rhs);
std::tie(opcode, args) =
evaluate::GetTopLevelOperationIgnoreResizing(input);
atomArg = nullptr; // No longer valid.
}
for (auto &arg : args) {
if (!evaluate::IsSameOrConvertOf(arg, atom)) {
mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc));
Expand Down
Loading