Skip to content

Commit 4f6ae2a

Browse files
authored
[flang][OpenMP] Reassociate ATOMIC update expressions (#153098)
An atomic update expression of form x = x + a + b is technically illegal, since the right-hand side is parsed as (x+a)+b, and the atomic variable x should be an argument to the top-level +. When the type of x is integer, the result of (x+a)+b is guaranteed to be the same as x+(a+b), so instead of reporting an error, the compiler can treat (x+a)+b as x+(a+b). This PR implements this kind of reassociation for integral types, and for the two arithmetic associative/commutative operators: + and *.
1 parent c198c14 commit 4f6ae2a

File tree

5 files changed

+329
-46
lines changed

5 files changed

+329
-46
lines changed

flang/lib/Semantics/check-omp-atomic.cpp

Lines changed: 241 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
#include "check-omp-structure.h"
1414

1515
#include "flang/Common/indirection.h"
16+
#include "flang/Common/template.h"
1617
#include "flang/Evaluate/expression.h"
18+
#include "flang/Evaluate/match.h"
1719
#include "flang/Evaluate/rewrite.h"
1820
#include "flang/Evaluate/tools.h"
1921
#include "flang/Parser/char-block.h"
@@ -50,6 +52,127 @@ static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) {
5052
return !(e == f);
5153
}
5254

55+
namespace {
56+
template <typename...> struct IsIntegral {
57+
static constexpr bool value{false};
58+
};
59+
60+
template <common::TypeCategory C, int K>
61+
struct IsIntegral<evaluate::Type<C, K>> {
62+
static constexpr bool value{//
63+
C == common::TypeCategory::Integer ||
64+
C == common::TypeCategory::Unsigned ||
65+
C == common::TypeCategory::Logical};
66+
};
67+
68+
template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
69+
70+
template <typename T, typename Op0, typename Op1>
71+
using ReassocOpBase = evaluate::match::AnyOfPattern< //
72+
evaluate::match::Add<T, Op0, Op1>, //
73+
evaluate::match::Mul<T, Op0, Op1>>;
74+
75+
template <typename T, typename Op0, typename Op1>
76+
struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
77+
using Base = ReassocOpBase<T, Op0, Op1>;
78+
using Base::Base;
79+
};
80+
81+
template <typename T, typename Op0, typename Op1>
82+
ReassocOp<T, Op0, Op1> reassocOp(const Op0 &op0, const Op1 &op1) {
83+
return ReassocOp<T, Op0, Op1>(op0, op1);
84+
}
85+
} // namespace
86+
87+
struct ReassocRewriter : public evaluate::rewrite::Identity {
88+
using Id = evaluate::rewrite::Identity;
89+
using Id::operator();
90+
struct NonIntegralTag {};
91+
92+
ReassocRewriter(const SomeExpr &atom) : atom_(atom) {}
93+
94+
// Try to find cases where the input expression is of the form
95+
// (1) (a . b) . c, or
96+
// (2) a . (b . c),
97+
// where . denotes an associative operation (currently + or *), and a, b, c
98+
// are some subexpresions.
99+
// If one of the operands in the nested operation is the atomic variable
100+
// (with some possible type conversions applied to it), bring it to the
101+
// top-level operation, and move the top-level operand into the nested
102+
// operation.
103+
// For example, assuming x is the atomic variable:
104+
// (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b.
105+
template <typename T, typename U,
106+
typename = std::enable_if_t<is_integral_v<T>>>
107+
evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
108+
// As per the above comment, there are 3 subexpressions involved in this
109+
// transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
110+
// same as U, plus it will store a pointer (ref) to the matched expression.
111+
// When the match is successful, the sub[i].ref will point to a, b, x (in
112+
// some order) from the example above.
113+
evaluate::match::Expr<T> sub[3];
114+
auto inner{reassocOp<T>(sub[0], sub[1])};
115+
auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
116+
auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
117+
// There is no way to ensure that the outer operation is the same as
118+
// the inner one. They are matched independently, so we need to compare
119+
// the index in the member variant that represents the matched type.
120+
if ((match(outer1, x) && outer1.ref.index() == inner.ref.index()) ||
121+
(match(outer2, x) && outer2.ref.index() == inner.ref.index())) {
122+
size_t atomIdx{[&]() { // sub[atomIdx] will be the atom.
123+
size_t idx;
124+
for (idx = 0; idx != 3; ++idx) {
125+
if (IsAtom(*sub[idx].ref)) {
126+
break;
127+
}
128+
}
129+
return idx;
130+
}()};
131+
132+
if (atomIdx > 2) {
133+
return Id::operator()(std::move(x), u);
134+
}
135+
return common::visit(
136+
[&](auto &&s) {
137+
using Expr = evaluate::Expr<T>;
138+
using TypeS = llvm::remove_cvref_t<decltype(s)>;
139+
// This visitor has to be semantically correct for all possible
140+
// types of s even though at runtime s will only be one of the
141+
// matched types.
142+
// Limit the construction to the operation types that we tried
143+
// to match (otherwise TypeS(op1, op2) would fail for non-binary
144+
// operations).
145+
if constexpr (common::HasMember<TypeS,
146+
typename decltype(outer1)::MatchTypes>) {
147+
Expr atom{*sub[atomIdx].ref};
148+
Expr op1{*sub[(atomIdx + 1) % 3].ref};
149+
Expr op2{*sub[(atomIdx + 2) % 3].ref};
150+
return Expr(
151+
TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
152+
} else {
153+
return Expr(TypeS(s));
154+
}
155+
},
156+
evaluate::match::deparen(x).u);
157+
}
158+
return Id::operator()(std::move(x), u);
159+
}
160+
161+
template <typename T, typename U,
162+
typename = std::enable_if_t<!is_integral_v<T>>>
163+
evaluate::Expr<T> operator()(
164+
evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
165+
return Id::operator()(std::move(x), u);
166+
}
167+
168+
private:
169+
template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
170+
return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
171+
}
172+
173+
const SomeExpr &atom_;
174+
};
175+
53176
struct AnalyzedCondStmt {
54177
SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
55178
parser::CharBlock source;
@@ -199,6 +322,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
199322
llvm_unreachable("Could not find assignment operator");
200323
}
201324

325+
static std::vector<SomeExpr> GetNonAtomExpressions(
326+
const SomeExpr &atom, const std::vector<SomeExpr> &exprs) {
327+
std::vector<SomeExpr> nonAtom;
328+
for (const SomeExpr &e : exprs) {
329+
if (!IsSameOrConvertOf(e, atom)) {
330+
nonAtom.push_back(e);
331+
}
332+
}
333+
return nonAtom;
334+
}
335+
336+
static std::vector<SomeExpr> GetNonAtomArguments(
337+
const SomeExpr &atom, const SomeExpr &expr) {
338+
if (auto &&maybe{GetConvertInput(expr)}) {
339+
return GetNonAtomExpressions(
340+
atom, GetTopLevelOperationIgnoreResizing(*maybe).second);
341+
}
342+
return {};
343+
}
344+
202345
static bool IsCheckForAssociated(const SomeExpr &cond) {
203346
return GetTopLevelOperationIgnoreResizing(cond).first ==
204347
operation::Operator::Associated;
@@ -625,7 +768,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment(
625768
}
626769
}
627770

628-
void OmpStructureChecker::CheckAtomicUpdateAssignment(
771+
std::optional<evaluate::Assignment>
772+
OmpStructureChecker::CheckAtomicUpdateAssignment(
629773
const evaluate::Assignment &update, parser::CharBlock source) {
630774
// [6.0:191:1-7]
631775
// An update structured block is update-statement, an update statement
@@ -641,14 +785,46 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
641785
if (!IsVarOrFunctionRef(atom)) {
642786
ErrorShouldBeVariable(atom, rsrc);
643787
// Skip other checks.
644-
return;
788+
return std::nullopt;
645789
}
646790

647791
CheckAtomicVariable(atom, lsrc);
648792

793+
auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs(
794+
atom, update.rhs, source, /*suppressDiagnostics=*/true)};
795+
796+
if (!hasErrors) {
797+
CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source);
798+
return std::nullopt;
799+
} else if (tryReassoc) {
800+
ReassocRewriter ra(atom);
801+
SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)};
802+
803+
std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs(
804+
atom, raRhs, source, /*suppressDiagnostics=*/true);
805+
if (!hasErrors) {
806+
CheckStorageOverlap(atom, GetNonAtomArguments(atom, raRhs), source);
807+
808+
evaluate::Assignment raAssign(update);
809+
raAssign.rhs = raRhs;
810+
return raAssign;
811+
}
812+
}
813+
814+
// This is guaranteed to report errors.
815+
CheckAtomicUpdateAssignmentRhs(
816+
atom, update.rhs, source, /*suppressDiagnostics=*/false);
817+
return std::nullopt;
818+
}
819+
820+
std::pair<bool, bool> OmpStructureChecker::CheckAtomicUpdateAssignmentRhs(
821+
const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source,
822+
bool suppressDiagnostics) {
823+
auto [lsrc, rsrc]{SplitAssignmentSource(source)};
824+
649825
std::pair<operation::Operator, std::vector<SomeExpr>> top{
650826
operation::Operator::Unknown, {}};
651-
if (auto &&maybeInput{GetConvertInput(update.rhs)}) {
827+
if (auto &&maybeInput{GetConvertInput(rhs)}) {
652828
top = GetTopLevelOperationIgnoreResizing(*maybeInput);
653829
}
654830
switch (top.first) {
@@ -665,29 +841,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
665841
case operation::Operator::Identity:
666842
break;
667843
case operation::Operator::Call:
668-
context_.Say(source,
669-
"A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
670-
return;
844+
if (!suppressDiagnostics) {
845+
context_.Say(source,
846+
"A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
847+
}
848+
return std::make_pair(true, false);
671849
case operation::Operator::Convert:
672-
context_.Say(source,
673-
"An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
674-
return;
850+
if (!suppressDiagnostics) {
851+
context_.Say(source,
852+
"An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
853+
}
854+
return std::make_pair(true, false);
675855
case operation::Operator::Intrinsic:
676-
context_.Say(source,
677-
"This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
678-
return;
856+
if (!suppressDiagnostics) {
857+
context_.Say(source,
858+
"This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
859+
}
860+
return std::make_pair(true, false);
679861
case operation::Operator::Constant:
680862
case operation::Operator::Unknown:
681-
context_.Say(
682-
source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
683-
return;
863+
if (!suppressDiagnostics) {
864+
context_.Say(
865+
source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
866+
}
867+
return std::make_pair(true, false);
684868
default:
685869
assert(
686870
top.first != operation::Operator::Identity && "Handle this separately");
687-
context_.Say(source,
688-
"The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
689-
operation::ToString(top.first));
690-
return;
871+
if (!suppressDiagnostics) {
872+
context_.Say(source,
873+
"The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
874+
operation::ToString(top.first));
875+
}
876+
return std::make_pair(true, false);
691877
}
692878
// Check how many times `atom` occurs as an argument, if it's a subexpression
693879
// of an argument, and collect the non-atom arguments.
@@ -708,39 +894,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
708894
return count;
709895
}()};
710896

711-
bool hasError{false};
897+
bool hasError{false}, tryReassoc{false};
712898
if (subExpr) {
713-
context_.Say(rsrc,
714-
"The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
715-
atom.AsFortran(), subExpr->AsFortran());
899+
if (!suppressDiagnostics) {
900+
context_.Say(rsrc,
901+
"The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
902+
atom.AsFortran(), subExpr->AsFortran());
903+
}
716904
hasError = true;
717905
}
718906
if (top.first == operation::Operator::Identity) {
719907
// This is "x = y".
720908
assert((atomCount == 0 || atomCount == 1) && "Unexpected count");
721909
if (atomCount == 0) {
722-
context_.Say(rsrc,
723-
"The atomic variable %s should appear as an argument in the update operation"_err_en_US,
724-
atom.AsFortran());
910+
if (!suppressDiagnostics) {
911+
context_.Say(rsrc,
912+
"The atomic variable %s should appear as an argument in the update operation"_err_en_US,
913+
atom.AsFortran());
914+
}
725915
hasError = true;
726916
}
727917
} else {
728918
if (atomCount == 0) {
729-
context_.Say(rsrc,
730-
"The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
731-
atom.AsFortran(), operation::ToString(top.first));
919+
if (!suppressDiagnostics) {
920+
context_.Say(rsrc,
921+
"The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
922+
atom.AsFortran(), operation::ToString(top.first));
923+
}
924+
// If `atom` is a proper subexpression, and it not present as an
925+
// argument on its own, reassociation may be able to help.
926+
tryReassoc = subExpr.has_value();
732927
hasError = true;
733928
} else if (atomCount > 1) {
734-
context_.Say(rsrc,
735-
"The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
736-
atom.AsFortran(), operation::ToString(top.first));
929+
if (!suppressDiagnostics) {
930+
context_.Say(rsrc,
931+
"The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
932+
atom.AsFortran(), operation::ToString(top.first));
933+
}
737934
hasError = true;
738935
}
739936
}
740937

741-
if (!hasError) {
742-
CheckStorageOverlap(atom, nonAtom, source);
743-
}
938+
return std::make_pair(hasError, tryReassoc);
744939
}
745940

746941
void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment(
@@ -843,11 +1038,13 @@ void OmpStructureChecker::CheckAtomicUpdateOnly(
8431038
SourcedActionStmt action{GetActionStmt(&body.front())};
8441039
if (auto maybeUpdate{GetEvaluateAssignment(action.stmt)}) {
8451040
const SomeExpr &atom{maybeUpdate->lhs};
846-
CheckAtomicUpdateAssignment(*maybeUpdate, action.source);
1041+
auto maybeAssign{
1042+
CheckAtomicUpdateAssignment(*maybeUpdate, action.source)};
1043+
auto &updateAssign{maybeAssign.has_value() ? maybeAssign : maybeUpdate};
8471044

8481045
using Analysis = parser::OpenMPAtomicConstruct::Analysis;
8491046
x.analysis = AtomicAnalysis(atom)
850-
.addOp0(Analysis::Update, maybeUpdate)
1047+
.addOp0(Analysis::Update, updateAssign)
8511048
.addOp1(Analysis::None);
8521049
} else if (!IsAssignment(action.stmt)) {
8531050
context_.Say(
@@ -963,29 +1160,32 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
9631160
using Analysis = parser::OpenMPAtomicConstruct::Analysis;
9641161
int action;
9651162

1163+
std::optional<evaluate::Assignment> updateAssign{update};
9661164
if (IsMaybeAtomicWrite(update)) {
9671165
action = Analysis::Write;
9681166
CheckAtomicWriteAssignment(update, uact.source);
9691167
} else {
9701168
action = Analysis::Update;
971-
CheckAtomicUpdateAssignment(update, uact.source);
1169+
if (auto &&maybe{CheckAtomicUpdateAssignment(update, uact.source)}) {
1170+
updateAssign = maybe;
1171+
}
9721172
}
9731173
CheckAtomicCaptureAssignment(capture, atom, cact.source);
9741174

975-
if (IsPointerAssignment(update) != IsPointerAssignment(capture)) {
1175+
if (IsPointerAssignment(*updateAssign) != IsPointerAssignment(capture)) {
9761176
context_.Say(cact.source,
9771177
"The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments"_err_en_US);
9781178
return;
9791179
}
9801180

9811181
if (GetActionStmt(&body.front()).stmt == uact.stmt) {
9821182
x.analysis = AtomicAnalysis(atom)
983-
.addOp0(action, update)
1183+
.addOp0(action, updateAssign)
9841184
.addOp1(Analysis::Read, capture);
9851185
} else {
9861186
x.analysis = AtomicAnalysis(atom)
9871187
.addOp0(Analysis::Read, capture)
988-
.addOp1(action, update);
1188+
.addOp1(action, updateAssign);
9891189
}
9901190
}
9911191

0 commit comments

Comments
 (0)