Skip to content
282 changes: 241 additions & 41 deletions flang/lib/Semantics/check-omp-atomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
#include "check-omp-structure.h"

#include "flang/Common/indirection.h"
#include "flang/Common/template.h"
#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/match.h"
#include "flang/Evaluate/rewrite.h"
#include "flang/Evaluate/tools.h"
#include "flang/Parser/char-block.h"
Expand Down Expand Up @@ -50,6 +52,127 @@ static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) {
return !(e == f);
}

namespace {
template <typename...> struct IsIntegral {
static constexpr bool value{false};
};

template <common::TypeCategory C, int K>
struct IsIntegral<evaluate::Type<C, K>> {
static constexpr bool value{//
C == common::TypeCategory::Integer ||
C == common::TypeCategory::Unsigned ||
C == common::TypeCategory::Logical};
};

template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};

template <typename T, typename Op0, typename Op1>
using ReassocOpBase = evaluate::match::AnyOfPattern< //
evaluate::match::Add<T, Op0, Op1>, //
evaluate::match::Mul<T, Op0, Op1>>;

template <typename T, typename Op0, typename Op1>
struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
using Base = ReassocOpBase<T, Op0, Op1>;
using Base::Base;
};

template <typename T, typename Op0, typename Op1>
ReassocOp<T, Op0, Op1> reassocOp(const Op0 &op0, const Op1 &op1) {
return ReassocOp<T, Op0, Op1>(op0, op1);
}
} // namespace

struct ReassocRewriter : public evaluate::rewrite::Identity {
using Id = evaluate::rewrite::Identity;
using Id::operator();
struct NonIntegralTag {};

ReassocRewriter(const SomeExpr &atom) : atom_(atom) {}

// Try to find cases where the input expression is of the form
// (1) (a . b) . c, or
// (2) a . (b . c),
// where . denotes an associative operation (currently + or *), and a, b, c
// are some subexpresions.
// If one of the operands in the nested operation is the atomic variable
// (with some possible type conversions applied to it), bring it to the
// top-level operation, and move the top-level operand into the nested
// operation.
// For example, assuming x is the atomic variable:
// (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b.
template <typename T, typename U,
typename = std::enable_if_t<is_integral_v<T>>>
evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
// As per the above comment, there are 3 subexpressions involved in this
// transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
// same as U, plus it will store a pointer (ref) to the matched expression.
// When the match is successful, the sub[i].ref will point to a, b, x (in
// some order) from the example above.
evaluate::match::Expr<T> sub[3];
auto inner{reassocOp<T>(sub[0], sub[1])};
auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
// There is no way to ensure that the outer operation is the same as
// the inner one. They are matched independently, so we need to compare
// the index in the member variant that represents the matched type.
if ((match(outer1, x) && outer1.ref.index() == inner.ref.index()) ||
(match(outer2, x) && outer2.ref.index() == inner.ref.index())) {
size_t atomIdx{[&]() { // sub[atomIdx] will be the atom.
size_t idx;
for (idx = 0; idx != 3; ++idx) {
if (IsAtom(*sub[idx].ref)) {
break;
}
}
return idx;
}()};

if (atomIdx > 2) {
return Id::operator()(std::move(x), u);
}
return common::visit(
[&](auto &&s) {
using Expr = evaluate::Expr<T>;
using TypeS = llvm::remove_cvref_t<decltype(s)>;
// This visitor has to be semantically correct for all possible
// types of s even though at runtime s will only be one of the
// matched types.
// Limit the construction to the operation types that we tried
// to match (otherwise TypeS(op1, op2) would fail for non-binary
// operations).
if constexpr (common::HasMember<TypeS,
typename decltype(outer1)::MatchTypes>) {
Expr atom{*sub[atomIdx].ref};
Expr op1{*sub[(atomIdx + 1) % 3].ref};
Expr op2{*sub[(atomIdx + 2) % 3].ref};
return Expr(
TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
} else {
return Expr(TypeS(s));
}
},
evaluate::match::deparen(x).u);
}
return Id::operator()(std::move(x), u);
}

template <typename T, typename U,
typename = std::enable_if_t<!is_integral_v<T>>>
evaluate::Expr<T> operator()(
evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
return Id::operator()(std::move(x), u);
}

private:
template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
}

const SomeExpr &atom_;
};

struct AnalyzedCondStmt {
SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
parser::CharBlock source;
Expand Down Expand Up @@ -199,6 +322,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
llvm_unreachable("Could not find assignment operator");
}

static std::vector<SomeExpr> GetNonAtomExpressions(
const SomeExpr &atom, const std::vector<SomeExpr> &exprs) {
std::vector<SomeExpr> nonAtom;
for (const SomeExpr &e : exprs) {
if (!IsSameOrConvertOf(e, atom)) {
nonAtom.push_back(e);
}
}
return nonAtom;
}

static std::vector<SomeExpr> GetNonAtomArguments(
const SomeExpr &atom, const SomeExpr &expr) {
if (auto &&maybe{GetConvertInput(expr)}) {
return GetNonAtomExpressions(
atom, GetTopLevelOperationIgnoreResizing(*maybe).second);
}
return {};
}

static bool IsCheckForAssociated(const SomeExpr &cond) {
return GetTopLevelOperationIgnoreResizing(cond).first ==
operation::Operator::Associated;
Expand Down Expand Up @@ -625,7 +768,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment(
}
}

void OmpStructureChecker::CheckAtomicUpdateAssignment(
std::optional<evaluate::Assignment>
OmpStructureChecker::CheckAtomicUpdateAssignment(
const evaluate::Assignment &update, parser::CharBlock source) {
// [6.0:191:1-7]
// An update structured block is update-statement, an update statement
Expand All @@ -641,14 +785,46 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
if (!IsVarOrFunctionRef(atom)) {
ErrorShouldBeVariable(atom, rsrc);
// Skip other checks.
return;
return std::nullopt;
}

CheckAtomicVariable(atom, lsrc);

auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs(
atom, update.rhs, source, /*suppressDiagnostics=*/true)};

if (!hasErrors) {
CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source);
return std::nullopt;
} else if (tryReassoc) {
ReassocRewriter ra(atom);
SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)};

std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs(
atom, raRhs, source, /*suppressDiagnostics=*/true);
if (!hasErrors) {
CheckStorageOverlap(atom, GetNonAtomArguments(atom, raRhs), source);

evaluate::Assignment raAssign(update);
raAssign.rhs = raRhs;
return raAssign;
}
}

// This is guaranteed to report errors.
CheckAtomicUpdateAssignmentRhs(
atom, update.rhs, source, /*suppressDiagnostics=*/false);
return std::nullopt;
}

std::pair<bool, bool> OmpStructureChecker::CheckAtomicUpdateAssignmentRhs(
const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source,
bool suppressDiagnostics) {
auto [lsrc, rsrc]{SplitAssignmentSource(source)};

std::pair<operation::Operator, std::vector<SomeExpr>> top{
operation::Operator::Unknown, {}};
if (auto &&maybeInput{GetConvertInput(update.rhs)}) {
if (auto &&maybeInput{GetConvertInput(rhs)}) {
top = GetTopLevelOperationIgnoreResizing(*maybeInput);
}
switch (top.first) {
Expand All @@ -665,29 +841,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
case operation::Operator::Identity:
break;
case operation::Operator::Call:
context_.Say(source,
"A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
return;
if (!suppressDiagnostics) {
context_.Say(source,
"A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
}
return std::make_pair(true, false);
case operation::Operator::Convert:
context_.Say(source,
"An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
return;
if (!suppressDiagnostics) {
context_.Say(source,
"An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
}
return std::make_pair(true, false);
case operation::Operator::Intrinsic:
context_.Say(source,
"This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
return;
if (!suppressDiagnostics) {
context_.Say(source,
"This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
}
return std::make_pair(true, false);
case operation::Operator::Constant:
case operation::Operator::Unknown:
context_.Say(
source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
return;
if (!suppressDiagnostics) {
context_.Say(
source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
}
return std::make_pair(true, false);
default:
assert(
top.first != operation::Operator::Identity && "Handle this separately");
context_.Say(source,
"The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
operation::ToString(top.first));
return;
if (!suppressDiagnostics) {
context_.Say(source,
"The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
operation::ToString(top.first));
}
return std::make_pair(true, false);
}
// Check how many times `atom` occurs as an argument, if it's a subexpression
// of an argument, and collect the non-atom arguments.
Expand All @@ -708,39 +894,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
return count;
}()};

bool hasError{false};
bool hasError{false}, tryReassoc{false};
if (subExpr) {
context_.Say(rsrc,
"The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
atom.AsFortran(), subExpr->AsFortran());
if (!suppressDiagnostics) {
context_.Say(rsrc,
"The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
atom.AsFortran(), subExpr->AsFortran());
}
hasError = true;
}
if (top.first == operation::Operator::Identity) {
// This is "x = y".
assert((atomCount == 0 || atomCount == 1) && "Unexpected count");
if (atomCount == 0) {
context_.Say(rsrc,
"The atomic variable %s should appear as an argument in the update operation"_err_en_US,
atom.AsFortran());
if (!suppressDiagnostics) {
context_.Say(rsrc,
"The atomic variable %s should appear as an argument in the update operation"_err_en_US,
atom.AsFortran());
}
hasError = true;
}
} else {
if (atomCount == 0) {
context_.Say(rsrc,
"The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
atom.AsFortran(), operation::ToString(top.first));
if (!suppressDiagnostics) {
context_.Say(rsrc,
"The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
atom.AsFortran(), operation::ToString(top.first));
}
// If `atom` is a proper subexpression, and it not present as an
// argument on its own, reassociation may be able to help.
tryReassoc = subExpr.has_value();
hasError = true;
} else if (atomCount > 1) {
context_.Say(rsrc,
"The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
atom.AsFortran(), operation::ToString(top.first));
if (!suppressDiagnostics) {
context_.Say(rsrc,
"The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
atom.AsFortran(), operation::ToString(top.first));
}
hasError = true;
}
}

if (!hasError) {
CheckStorageOverlap(atom, nonAtom, source);
}
return std::make_pair(hasError, tryReassoc);
}

void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment(
Expand Down Expand Up @@ -843,11 +1038,13 @@ void OmpStructureChecker::CheckAtomicUpdateOnly(
SourcedActionStmt action{GetActionStmt(&body.front())};
if (auto maybeUpdate{GetEvaluateAssignment(action.stmt)}) {
const SomeExpr &atom{maybeUpdate->lhs};
CheckAtomicUpdateAssignment(*maybeUpdate, action.source);
auto maybeAssign{
CheckAtomicUpdateAssignment(*maybeUpdate, action.source)};
auto &updateAssign{maybeAssign.has_value() ? maybeAssign : maybeUpdate};

using Analysis = parser::OpenMPAtomicConstruct::Analysis;
x.analysis = AtomicAnalysis(atom)
.addOp0(Analysis::Update, maybeUpdate)
.addOp0(Analysis::Update, updateAssign)
.addOp1(Analysis::None);
} else if (!IsAssignment(action.stmt)) {
context_.Say(
Expand Down Expand Up @@ -963,29 +1160,32 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
using Analysis = parser::OpenMPAtomicConstruct::Analysis;
int action;

std::optional<evaluate::Assignment> updateAssign{update};
if (IsMaybeAtomicWrite(update)) {
action = Analysis::Write;
CheckAtomicWriteAssignment(update, uact.source);
} else {
action = Analysis::Update;
CheckAtomicUpdateAssignment(update, uact.source);
if (auto &&maybe{CheckAtomicUpdateAssignment(update, uact.source)}) {
updateAssign = maybe;
}
}
CheckAtomicCaptureAssignment(capture, atom, cact.source);

if (IsPointerAssignment(update) != IsPointerAssignment(capture)) {
if (IsPointerAssignment(*updateAssign) != IsPointerAssignment(capture)) {
context_.Say(cact.source,
"The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments"_err_en_US);
return;
}

if (GetActionStmt(&body.front()).stmt == uact.stmt) {
x.analysis = AtomicAnalysis(atom)
.addOp0(action, update)
.addOp0(action, updateAssign)
.addOp1(Analysis::Read, capture);
} else {
x.analysis = AtomicAnalysis(atom)
.addOp0(Analysis::Read, capture)
.addOp1(action, update);
.addOp1(action, updateAssign);
}
}

Expand Down
Loading