- 
                Notifications
    
You must be signed in to change notification settings  - Fork 15.1k
 
[flang][OpenMP] Reassociate ATOMIC update expressions #153488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
An atomic update expression of form x = x + a + b is technically illegal, since the right-hand side is parsed as (x+a)+b, and the atomic variable x should be an argument to the top-level +. When the type of x is integer, the result of (x+a)+b is guaranteed to be the same as x+(a+b), so instead of reporting an error, the compiler can treat (x+a)+b as x+(a+b). This PR implements this kind of reassociation for integral types, and for the two arithmetic associative/commutative operators: + and *. Reinstate PR153098 one more time with fixes for the issues that came up: - unused variable "lsrc", - use of ‘outer1’ before deduction of ‘auto’.
| 
          
 @llvm/pr-subscribers-flang-semantics Author: Krzysztof Parzyszek (kparzysz) ChangesAn atomic update expression of form This PR implements this kind of reassociation for integral types, and for the two arithmetic associative/commutative operators: + and *. Reinstate PR153098 one more time with fixes for the issues that came up: 
 Patch is 22.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153488.diff 5 Files Affected: 
 diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp
index 0c0e6158485e9..9d92be6327fdb 100644
--- a/flang/lib/Semantics/check-omp-atomic.cpp
+++ b/flang/lib/Semantics/check-omp-atomic.cpp
@@ -13,7 +13,9 @@
 #include "check-omp-structure.h"
 
 #include "flang/Common/indirection.h"
+#include "flang/Common/template.h"
 #include "flang/Evaluate/expression.h"
+#include "flang/Evaluate/match.h"
 #include "flang/Evaluate/rewrite.h"
 #include "flang/Evaluate/tools.h"
 #include "flang/Parser/char-block.h"
@@ -50,6 +52,138 @@ static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) {
   return !(e == f);
 }
 
+namespace {
+template <typename...> struct IsIntegral {
+  static constexpr bool value{false};
+};
+
+template <common::TypeCategory C, int K>
+struct IsIntegral<evaluate::Type<C, K>> {
+  static constexpr bool value{//
+      C == common::TypeCategory::Integer ||
+      C == common::TypeCategory::Unsigned ||
+      C == common::TypeCategory::Logical};
+};
+
+template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
+
+template <typename T, typename Op0, typename Op1>
+using ReassocOpBase = evaluate::match::AnyOfPattern< //
+    evaluate::match::Add<T, Op0, Op1>, //
+    evaluate::match::Mul<T, Op0, Op1>>;
+
+template <typename T, typename Op0, typename Op1>
+struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
+  using Base = ReassocOpBase<T, Op0, Op1>;
+  using Base::Base;
+};
+
+template <typename T, typename Op0, typename Op1>
+ReassocOp<T, Op0, Op1> reassocOp(const Op0 &op0, const Op1 &op1) {
+  return ReassocOp<T, Op0, Op1>(op0, op1);
+}
+} // namespace
+
+struct ReassocRewriter : public evaluate::rewrite::Identity {
+  using Id = evaluate::rewrite::Identity;
+  using Id::operator();
+  struct NonIntegralTag {};
+
+  ReassocRewriter(const SomeExpr &atom) : atom_(atom) {}
+
+  // Try to find cases where the input expression is of the form
+  // (1) (a . b) . c, or
+  // (2) a . (b . c),
+  // where . denotes an associative operation (currently + or *), and a, b, c
+  // are some subexpresions.
+  // If one of the operands in the nested operation is the atomic variable
+  // (with some possible type conversions applied to it), bring it to the
+  // top-level operation, and move the top-level operand into the nested
+  // operation.
+  // For example, assuming x is the atomic variable:
+  //   (a + x) + b  ->  (a + b) + x,  i.e. (conceptually) swap x and b.
+  template <typename T, typename U,
+      typename = std::enable_if_t<is_integral_v<T>>>
+  evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
+    // As per the above comment, there are 3 subexpressions involved in this
+    // transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
+    // same as U, plus it will store a pointer (ref) to the matched expression.
+    // When the match is successful, the sub[i].ref will point to a, b, x (in
+    // some order) from the example above.
+    evaluate::match::Expr<T> sub[3];
+    auto inner{reassocOp<T>(sub[0], sub[1])};
+    auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
+    auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
+#if !defined(__clang__) && !defined(_MSC_VER) && \
+      (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
+    // If GCC version < 8.5, use this definition. For the other definition
+    // (which is equivalent), GCC 7.5 emits a somewhat cryptic error:
+    //    use of ‘outer1’ before deduction of ‘auto’
+    // inside of the visitor function in common::visit.
+    // Since this works with clang, MSVC and at least GCC 8.5, I'm assuming
+    // that this is some kind of a GCC issue.
+    using MatchTypes = std::tuple<evaluate::Add<T>, evaluate::Multiply<T>>;
+#else
+    using MatchTypes = typename decltype(outer1)::MatchTypes;
+#endif
+    // There is no way to ensure that the outer operation is the same as
+    // the inner one. They are matched independently, so we need to compare
+    // the index in the member variant that represents the matched type.
+    if ((match(outer1, x) && outer1.ref.index() == inner.ref.index()) ||
+        (match(outer2, x) && outer2.ref.index() == inner.ref.index())) {
+      size_t atomIdx{[&]() { // sub[atomIdx] will be the atom.
+        size_t idx;
+        for (idx = 0; idx != 3; ++idx) {
+          if (IsAtom(*sub[idx].ref)) {
+            break;
+          }
+        }
+        return idx;
+      }()};
+
+      if (atomIdx > 2) {
+        return Id::operator()(std::move(x), u);
+      }
+      return common::visit(
+          [&](auto &&s) {
+            using Expr = evaluate::Expr<T>;
+            using TypeS = llvm::remove_cvref_t<decltype(s)>;
+            // This visitor has to be semantically correct for all possible
+            // types of s even though at runtime s will only be one of the
+            // matched types.
+            // Limit the construction to the operation types that we tried
+            // to match (otherwise TypeS(op1, op2) would fail for non-binary
+            // operations).
+            if constexpr (common::HasMember<TypeS, MatchTypes>) {
+              Expr atom{*sub[atomIdx].ref};
+              Expr op1{*sub[(atomIdx + 1) % 3].ref};
+              Expr op2{*sub[(atomIdx + 2) % 3].ref};
+              return Expr(
+                  TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
+            } else {
+              return Expr(TypeS(s));
+            }
+          },
+          evaluate::match::deparen(x).u);
+    }
+    return Id::operator()(std::move(x), u);
+  }
+
+  template <typename T, typename U,
+      typename = std::enable_if_t<!is_integral_v<T>>>
+  evaluate::Expr<T> operator()(
+      evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
+    return Id::operator()(std::move(x), u);
+  }
+
+private:
+  template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
+    return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
+  }
+
+  const SomeExpr &atom_;
+};
+
 struct AnalyzedCondStmt {
   SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
   parser::CharBlock source;
@@ -199,6 +333,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
   llvm_unreachable("Could not find assignment operator");
 }
 
+static std::vector<SomeExpr> GetNonAtomExpressions(
+    const SomeExpr &atom, const std::vector<SomeExpr> &exprs) {
+  std::vector<SomeExpr> nonAtom;
+  for (const SomeExpr &e : exprs) {
+    if (!IsSameOrConvertOf(e, atom)) {
+      nonAtom.push_back(e);
+    }
+  }
+  return nonAtom;
+}
+
+static std::vector<SomeExpr> GetNonAtomArguments(
+    const SomeExpr &atom, const SomeExpr &expr) {
+  if (auto &&maybe{GetConvertInput(expr)}) {
+    return GetNonAtomExpressions(
+        atom, GetTopLevelOperationIgnoreResizing(*maybe).second);
+  }
+  return {};
+}
+
 static bool IsCheckForAssociated(const SomeExpr &cond) {
   return GetTopLevelOperationIgnoreResizing(cond).first ==
       operation::Operator::Associated;
@@ -576,6 +730,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
     const evaluate::Assignment &capture, const SomeExpr &atom,
     parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
   const SomeExpr &cap{capture.lhs};
 
   if (!IsVarOrFunctionRef(atom)) {
@@ -592,6 +747,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
 void OmpStructureChecker::CheckAtomicReadAssignment(
     const evaluate::Assignment &read, parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
 
   if (auto maybe{GetConvertInput(read.rhs)}) {
     const SomeExpr &atom{*maybe};
@@ -625,7 +781,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment(
   }
 }
 
-void OmpStructureChecker::CheckAtomicUpdateAssignment(
+std::optional<evaluate::Assignment>
+OmpStructureChecker::CheckAtomicUpdateAssignment(
     const evaluate::Assignment &update, parser::CharBlock source) {
   // [6.0:191:1-7]
   // An update structured block is update-statement, an update statement
@@ -641,14 +798,47 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   if (!IsVarOrFunctionRef(atom)) {
     ErrorShouldBeVariable(atom, rsrc);
     // Skip other checks.
-    return;
+    return std::nullopt;
   }
 
   CheckAtomicVariable(atom, lsrc);
 
+  auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/true)};
+
+  if (!hasErrors) {
+    CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source);
+    return std::nullopt;
+  } else if (tryReassoc) {
+    ReassocRewriter ra(atom);
+    SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)};
+
+    std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs(
+        atom, raRhs, source, /*suppressDiagnostics=*/true);
+    if (!hasErrors) {
+      CheckStorageOverlap(atom, GetNonAtomArguments(atom, raRhs), source);
+
+      evaluate::Assignment raAssign(update);
+      raAssign.rhs = raRhs;
+      return raAssign;
+    }
+  }
+
+  // This is guaranteed to report errors.
+  CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/false);
+  return std::nullopt;
+}
+
+std::pair<bool, bool> OmpStructureChecker::CheckAtomicUpdateAssignmentRhs(
+    const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source,
+    bool suppressDiagnostics) {
+  auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
+
   std::pair<operation::Operator, std::vector<SomeExpr>> top{
       operation::Operator::Unknown, {}};
-  if (auto &&maybeInput{GetConvertInput(update.rhs)}) {
+  if (auto &&maybeInput{GetConvertInput(rhs)}) {
     top = GetTopLevelOperationIgnoreResizing(*maybeInput);
   }
   switch (top.first) {
@@ -665,29 +855,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   case operation::Operator::Identity:
     break;
   case operation::Operator::Call:
-    context_.Say(source,
-        "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Convert:
-    context_.Say(source,
-        "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Intrinsic:
-    context_.Say(source,
-        "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Constant:
   case operation::Operator::Unknown:
-    context_.Say(
-        source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(
+          source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   default:
     assert(
         top.first != operation::Operator::Identity && "Handle this separately");
-    context_.Say(source,
-        "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
-        operation::ToString(top.first));
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
+          operation::ToString(top.first));
+    }
+    return std::make_pair(true, false);
   }
   // Check how many times `atom` occurs as an argument, if it's a subexpression
   // of an argument, and collect the non-atom arguments.
@@ -708,39 +908,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
     return count;
   }()};
 
-  bool hasError{false};
+  bool hasError{false}, tryReassoc{false};
   if (subExpr) {
-    context_.Say(rsrc,
-        "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
-        atom.AsFortran(), subExpr->AsFortran());
+    if (!suppressDiagnostics) {
+      context_.Say(rsrc,
+          "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
+          atom.AsFortran(), subExpr->AsFortran());
+    }
     hasError = true;
   }
   if (top.first == operation::Operator::Identity) {
     // This is "x = y".
     assert((atomCount == 0 || atomCount == 1) && "Unexpected count");
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
-          atom.AsFortran());
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
+            atom.AsFortran());
+      }
       hasError = true;
     }
   } else {
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
+      // If `atom` is a proper subexpression, and it not present as an
+      // argument on its own, reassociation may be able to help.
+      tryReassoc = subExpr.has_value();
       hasError = true;
     } else if (atomCount > 1) {
-      context_.Say(rsrc,
-          "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
       hasError = true;
     }
   }
 
-  if (!hasError) {
-    CheckStorageOverlap(atom, nonAtom, source);
-  }
+  return std::make_pair(hasError, tryReassoc);
 }
 
 void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment(
@@ -843,11 +1052,13 @@ void OmpStructureChecker::CheckAtomicUpdateOnly(
     SourcedActionStmt action{GetActionStmt(&body.front())};
     if (auto maybeUpdate{GetEvaluateAssignment(action.stmt)}) {
       const SomeExpr &atom{maybeUpdate->lhs};
-      CheckAtomicUpdateAssignment(*maybeUpdate, action.source);
+      auto maybeAssign{
+          CheckAtomicUpdateAssignment(*maybeUpdate, action.source)};
+      auto &updateAssign{maybeAssign.has_value() ? maybeAssign : maybeUpdate};
 
       using Analysis = parser::OpenMPAtomicConstruct::Analysis;
       x.analysis = AtomicAnalysis(atom)
-                       .addOp0(Analysis::Update, maybeUpdate)
+                       .addOp0(Analysis::Update, updateAssign)
                        .addOp1(Analysis::None);
     } else if (!IsAssignment(action.stmt)) {
       context_.Say(
@@ -963,16 +1174,19 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
   using Analysis = parser::OpenMPAtomicConstruct::Analysis;
   int action;
 
+  std::optional<evaluate::Assignment> updateAssign{update};
   if (IsMaybeAtomicWrite(update)) {
     action = Analysis::Write;
     CheckAtomicWriteAssignment(update, uact.source);
   } else {
     action = Analysis::Update;
-    CheckAtomicUpdateAssignment(update, uact.source);
+    if (auto &&maybe{CheckAtomicUpdateAssignment(update, uact.source)}) {
+      updateAssign = maybe;
+    }
   }
   CheckAtomicCaptureAssignment(capture, atom, cact.source);
 
-  if (IsPointerAssignment(update) != IsPointerAssignment(capture)) {
+  if (IsPointerAssignment(*updateAssign) != IsPointerAssignment(capture)) {
     context_.Say(cact.source,
         "The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments"_err_en_US);
     return;
@@ -980,12 +1194,12 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
 
   if (GetActionStmt(&body.front()).stmt == uact.stmt) {
     x.analysis = AtomicAnalysis(atom)
-                     .addOp0(action, update)
+                     .addOp0(action, updateAssign)
                      .addOp1(Analysis::Read, capture);
   } else {
     x.analysis = AtomicAnalysis(atom)
                      .addOp0(Analysis::Read, capture)
-                     .addOp1(action, update);
+                     .addOp1(action, updateAssign);
   }
 }
 
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 6b33ca6ab583f..a973aee28d0e2 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -267,8 +267,10 @@ class OmpStructureChecker
       const evaluate::Assignment &read, parser::CharBlock source);
   void CheckAtomicWriteAssignment(
       const evaluate::Assignment &write, parser::CharBlock source);
-  void CheckAtomicUpdateAssignment(
+  std::optional<evaluate::Assignment> CheckAtomicUpdateAssignment(
       const evaluate::Assignment &update, parser::CharBlock source);
+  std::pair<bool, bool> CheckAtomicUpdateAssignmentRhs(const SomeExpr &atom,
+      const SomeExpr &rhs, parser::CharBlock source, bool suppressDiagnostics);
   void CheckAtomicConditionalUpdateAssignment(const SomeExpr &cond,
       parser::CharBlock condSource, const evaluate::Assignment &assign,
       parser::CharBlock assignSource);
diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
new file mode 100644
index 0000000000000..96ebb56b8d6ca
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
@@ -0,0 +1,75 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+subroutine f00(x, y)
+  implicit none
+  integer :: x, y
+
+  !$omp atomic update
+  x = ((x + 1) + y) + 2
+end
+
+!CHECK-LABEL: func.func @_QPf00
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %c1_i32, %[[LOAD_Y]] : i32
+!CHECK: %c2_i32 = arith.constant 2 : i32
+!CHECK: %[[Y_1_2:[0-9]+]] = arith.addi %[[Y_1]], %c2_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: i32):
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG]], %[[Y_1_2]] : i32
+!CHECK:   omp.yield(%[[ARG_P]] : i32)
+!CHECK: }
+
+
+subroutine f01(x, y)
+  implicit none
+  real :: x
+  integer :: y
+
+  !$omp atomic update
+  x = (int(x) + y) + 1
+end
+
+!CHECK-LABEL: func.func @_QPf01
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %[[LOAD_Y]], %c1_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<f32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: f32):
+!CHECK:   %[[ARG_I:[0-9]+]] = fir.convert %[[ARG]] : (f32) -> i32
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG_I]], %[[Y_1]] : i32
+!CHECK:   %[[ARG_F:[0-9]+]] = fir.convert %[[ARG_P]] : (i32) -> f32
+!CHECK:   omp.yield(%[[ARG_F]] : f32)
+!CHECK: }
+
+
+subroutine f02(x, a, b, c)
+  implicit none
+  integer(kind=4) :: x
+  integer(kind=8) :: a, b, c
+
+  !$omp atomic update
+  x = ((b + a) + x) + c
+end
+
+!CHECK-LABEL: func.func @_QPf02
+!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<i64>
+!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<i64>
+!CHECK: %[[A_B:[0-9]+]] = arith.addi %[[LOAD_B]], ...
[truncated]
 | 
    
| 
          
 @llvm/pr-subscribers-flang-openmp Author: Krzysztof Parzyszek (kparzysz) ChangesAn atomic update expression of form This PR implements this kind of reassociation for integral types, and for the two arithmetic associative/commutative operators: + and *. Reinstate PR153098 one more time with fixes for the issues that came up: 
 Patch is 22.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153488.diff 5 Files Affected: 
 diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp
index 0c0e6158485e9..9d92be6327fdb 100644
--- a/flang/lib/Semantics/check-omp-atomic.cpp
+++ b/flang/lib/Semantics/check-omp-atomic.cpp
@@ -13,7 +13,9 @@
 #include "check-omp-structure.h"
 
 #include "flang/Common/indirection.h"
+#include "flang/Common/template.h"
 #include "flang/Evaluate/expression.h"
+#include "flang/Evaluate/match.h"
 #include "flang/Evaluate/rewrite.h"
 #include "flang/Evaluate/tools.h"
 #include "flang/Parser/char-block.h"
@@ -50,6 +52,138 @@ static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) {
   return !(e == f);
 }
 
+namespace {
+template <typename...> struct IsIntegral {
+  static constexpr bool value{false};
+};
+
+template <common::TypeCategory C, int K>
+struct IsIntegral<evaluate::Type<C, K>> {
+  static constexpr bool value{//
+      C == common::TypeCategory::Integer ||
+      C == common::TypeCategory::Unsigned ||
+      C == common::TypeCategory::Logical};
+};
+
+template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
+
+template <typename T, typename Op0, typename Op1>
+using ReassocOpBase = evaluate::match::AnyOfPattern< //
+    evaluate::match::Add<T, Op0, Op1>, //
+    evaluate::match::Mul<T, Op0, Op1>>;
+
+template <typename T, typename Op0, typename Op1>
+struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
+  using Base = ReassocOpBase<T, Op0, Op1>;
+  using Base::Base;
+};
+
+template <typename T, typename Op0, typename Op1>
+ReassocOp<T, Op0, Op1> reassocOp(const Op0 &op0, const Op1 &op1) {
+  return ReassocOp<T, Op0, Op1>(op0, op1);
+}
+} // namespace
+
+struct ReassocRewriter : public evaluate::rewrite::Identity {
+  using Id = evaluate::rewrite::Identity;
+  using Id::operator();
+  struct NonIntegralTag {};
+
+  ReassocRewriter(const SomeExpr &atom) : atom_(atom) {}
+
+  // Try to find cases where the input expression is of the form
+  // (1) (a . b) . c, or
+  // (2) a . (b . c),
+  // where . denotes an associative operation (currently + or *), and a, b, c
+  // are some subexpresions.
+  // If one of the operands in the nested operation is the atomic variable
+  // (with some possible type conversions applied to it), bring it to the
+  // top-level operation, and move the top-level operand into the nested
+  // operation.
+  // For example, assuming x is the atomic variable:
+  //   (a + x) + b  ->  (a + b) + x,  i.e. (conceptually) swap x and b.
+  template <typename T, typename U,
+      typename = std::enable_if_t<is_integral_v<T>>>
+  evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
+    // As per the above comment, there are 3 subexpressions involved in this
+    // transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
+    // same as U, plus it will store a pointer (ref) to the matched expression.
+    // When the match is successful, the sub[i].ref will point to a, b, x (in
+    // some order) from the example above.
+    evaluate::match::Expr<T> sub[3];
+    auto inner{reassocOp<T>(sub[0], sub[1])};
+    auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
+    auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
+#if !defined(__clang__) && !defined(_MSC_VER) && \
+      (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
+    // If GCC version < 8.5, use this definition. For the other definition
+    // (which is equivalent), GCC 7.5 emits a somewhat cryptic error:
+    //    use of ‘outer1’ before deduction of ‘auto’
+    // inside of the visitor function in common::visit.
+    // Since this works with clang, MSVC and at least GCC 8.5, I'm assuming
+    // that this is some kind of a GCC issue.
+    using MatchTypes = std::tuple<evaluate::Add<T>, evaluate::Multiply<T>>;
+#else
+    using MatchTypes = typename decltype(outer1)::MatchTypes;
+#endif
+    // There is no way to ensure that the outer operation is the same as
+    // the inner one. They are matched independently, so we need to compare
+    // the index in the member variant that represents the matched type.
+    if ((match(outer1, x) && outer1.ref.index() == inner.ref.index()) ||
+        (match(outer2, x) && outer2.ref.index() == inner.ref.index())) {
+      size_t atomIdx{[&]() { // sub[atomIdx] will be the atom.
+        size_t idx;
+        for (idx = 0; idx != 3; ++idx) {
+          if (IsAtom(*sub[idx].ref)) {
+            break;
+          }
+        }
+        return idx;
+      }()};
+
+      if (atomIdx > 2) {
+        return Id::operator()(std::move(x), u);
+      }
+      return common::visit(
+          [&](auto &&s) {
+            using Expr = evaluate::Expr<T>;
+            using TypeS = llvm::remove_cvref_t<decltype(s)>;
+            // This visitor has to be semantically correct for all possible
+            // types of s even though at runtime s will only be one of the
+            // matched types.
+            // Limit the construction to the operation types that we tried
+            // to match (otherwise TypeS(op1, op2) would fail for non-binary
+            // operations).
+            if constexpr (common::HasMember<TypeS, MatchTypes>) {
+              Expr atom{*sub[atomIdx].ref};
+              Expr op1{*sub[(atomIdx + 1) % 3].ref};
+              Expr op2{*sub[(atomIdx + 2) % 3].ref};
+              return Expr(
+                  TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
+            } else {
+              return Expr(TypeS(s));
+            }
+          },
+          evaluate::match::deparen(x).u);
+    }
+    return Id::operator()(std::move(x), u);
+  }
+
+  template <typename T, typename U,
+      typename = std::enable_if_t<!is_integral_v<T>>>
+  evaluate::Expr<T> operator()(
+      evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
+    return Id::operator()(std::move(x), u);
+  }
+
+private:
+  template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
+    return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
+  }
+
+  const SomeExpr &atom_;
+};
+
 struct AnalyzedCondStmt {
   SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
   parser::CharBlock source;
@@ -199,6 +333,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
   llvm_unreachable("Could not find assignment operator");
 }
 
+static std::vector<SomeExpr> GetNonAtomExpressions(
+    const SomeExpr &atom, const std::vector<SomeExpr> &exprs) {
+  std::vector<SomeExpr> nonAtom;
+  for (const SomeExpr &e : exprs) {
+    if (!IsSameOrConvertOf(e, atom)) {
+      nonAtom.push_back(e);
+    }
+  }
+  return nonAtom;
+}
+
+static std::vector<SomeExpr> GetNonAtomArguments(
+    const SomeExpr &atom, const SomeExpr &expr) {
+  if (auto &&maybe{GetConvertInput(expr)}) {
+    return GetNonAtomExpressions(
+        atom, GetTopLevelOperationIgnoreResizing(*maybe).second);
+  }
+  return {};
+}
+
 static bool IsCheckForAssociated(const SomeExpr &cond) {
   return GetTopLevelOperationIgnoreResizing(cond).first ==
       operation::Operator::Associated;
@@ -576,6 +730,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
     const evaluate::Assignment &capture, const SomeExpr &atom,
     parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
   const SomeExpr &cap{capture.lhs};
 
   if (!IsVarOrFunctionRef(atom)) {
@@ -592,6 +747,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
 void OmpStructureChecker::CheckAtomicReadAssignment(
     const evaluate::Assignment &read, parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
 
   if (auto maybe{GetConvertInput(read.rhs)}) {
     const SomeExpr &atom{*maybe};
@@ -625,7 +781,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment(
   }
 }
 
-void OmpStructureChecker::CheckAtomicUpdateAssignment(
+std::optional<evaluate::Assignment>
+OmpStructureChecker::CheckAtomicUpdateAssignment(
     const evaluate::Assignment &update, parser::CharBlock source) {
   // [6.0:191:1-7]
   // An update structured block is update-statement, an update statement
@@ -641,14 +798,47 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   if (!IsVarOrFunctionRef(atom)) {
     ErrorShouldBeVariable(atom, rsrc);
     // Skip other checks.
-    return;
+    return std::nullopt;
   }
 
   CheckAtomicVariable(atom, lsrc);
 
+  auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/true)};
+
+  if (!hasErrors) {
+    CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source);
+    return std::nullopt;
+  } else if (tryReassoc) {
+    ReassocRewriter ra(atom);
+    SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)};
+
+    std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs(
+        atom, raRhs, source, /*suppressDiagnostics=*/true);
+    if (!hasErrors) {
+      CheckStorageOverlap(atom, GetNonAtomArguments(atom, raRhs), source);
+
+      evaluate::Assignment raAssign(update);
+      raAssign.rhs = raRhs;
+      return raAssign;
+    }
+  }
+
+  // This is guaranteed to report errors.
+  CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/false);
+  return std::nullopt;
+}
+
+std::pair<bool, bool> OmpStructureChecker::CheckAtomicUpdateAssignmentRhs(
+    const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source,
+    bool suppressDiagnostics) {
+  auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
+
   std::pair<operation::Operator, std::vector<SomeExpr>> top{
       operation::Operator::Unknown, {}};
-  if (auto &&maybeInput{GetConvertInput(update.rhs)}) {
+  if (auto &&maybeInput{GetConvertInput(rhs)}) {
     top = GetTopLevelOperationIgnoreResizing(*maybeInput);
   }
   switch (top.first) {
@@ -665,29 +855,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   case operation::Operator::Identity:
     break;
   case operation::Operator::Call:
-    context_.Say(source,
-        "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Convert:
-    context_.Say(source,
-        "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Intrinsic:
-    context_.Say(source,
-        "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Constant:
   case operation::Operator::Unknown:
-    context_.Say(
-        source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(
+          source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   default:
     assert(
         top.first != operation::Operator::Identity && "Handle this separately");
-    context_.Say(source,
-        "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
-        operation::ToString(top.first));
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
+          operation::ToString(top.first));
+    }
+    return std::make_pair(true, false);
   }
   // Check how many times `atom` occurs as an argument, if it's a subexpression
   // of an argument, and collect the non-atom arguments.
@@ -708,39 +908,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
     return count;
   }()};
 
-  bool hasError{false};
+  bool hasError{false}, tryReassoc{false};
   if (subExpr) {
-    context_.Say(rsrc,
-        "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
-        atom.AsFortran(), subExpr->AsFortran());
+    if (!suppressDiagnostics) {
+      context_.Say(rsrc,
+          "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
+          atom.AsFortran(), subExpr->AsFortran());
+    }
     hasError = true;
   }
   if (top.first == operation::Operator::Identity) {
     // This is "x = y".
     assert((atomCount == 0 || atomCount == 1) && "Unexpected count");
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
-          atom.AsFortran());
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
+            atom.AsFortran());
+      }
       hasError = true;
     }
   } else {
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
+      // If `atom` is a proper subexpression, and it not present as an
+      // argument on its own, reassociation may be able to help.
+      tryReassoc = subExpr.has_value();
       hasError = true;
     } else if (atomCount > 1) {
-      context_.Say(rsrc,
-          "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
       hasError = true;
     }
   }
 
-  if (!hasError) {
-    CheckStorageOverlap(atom, nonAtom, source);
-  }
+  return std::make_pair(hasError, tryReassoc);
 }
 
 void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment(
@@ -843,11 +1052,13 @@ void OmpStructureChecker::CheckAtomicUpdateOnly(
     SourcedActionStmt action{GetActionStmt(&body.front())};
     if (auto maybeUpdate{GetEvaluateAssignment(action.stmt)}) {
       const SomeExpr &atom{maybeUpdate->lhs};
-      CheckAtomicUpdateAssignment(*maybeUpdate, action.source);
+      auto maybeAssign{
+          CheckAtomicUpdateAssignment(*maybeUpdate, action.source)};
+      auto &updateAssign{maybeAssign.has_value() ? maybeAssign : maybeUpdate};
 
       using Analysis = parser::OpenMPAtomicConstruct::Analysis;
       x.analysis = AtomicAnalysis(atom)
-                       .addOp0(Analysis::Update, maybeUpdate)
+                       .addOp0(Analysis::Update, updateAssign)
                        .addOp1(Analysis::None);
     } else if (!IsAssignment(action.stmt)) {
       context_.Say(
@@ -963,16 +1174,19 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
   using Analysis = parser::OpenMPAtomicConstruct::Analysis;
   int action;
 
+  std::optional<evaluate::Assignment> updateAssign{update};
   if (IsMaybeAtomicWrite(update)) {
     action = Analysis::Write;
     CheckAtomicWriteAssignment(update, uact.source);
   } else {
     action = Analysis::Update;
-    CheckAtomicUpdateAssignment(update, uact.source);
+    if (auto &&maybe{CheckAtomicUpdateAssignment(update, uact.source)}) {
+      updateAssign = maybe;
+    }
   }
   CheckAtomicCaptureAssignment(capture, atom, cact.source);
 
-  if (IsPointerAssignment(update) != IsPointerAssignment(capture)) {
+  if (IsPointerAssignment(*updateAssign) != IsPointerAssignment(capture)) {
     context_.Say(cact.source,
         "The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments"_err_en_US);
     return;
@@ -980,12 +1194,12 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
 
   if (GetActionStmt(&body.front()).stmt == uact.stmt) {
     x.analysis = AtomicAnalysis(atom)
-                     .addOp0(action, update)
+                     .addOp0(action, updateAssign)
                      .addOp1(Analysis::Read, capture);
   } else {
     x.analysis = AtomicAnalysis(atom)
                      .addOp0(Analysis::Read, capture)
-                     .addOp1(action, update);
+                     .addOp1(action, updateAssign);
   }
 }
 
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 6b33ca6ab583f..a973aee28d0e2 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -267,8 +267,10 @@ class OmpStructureChecker
       const evaluate::Assignment &read, parser::CharBlock source);
   void CheckAtomicWriteAssignment(
       const evaluate::Assignment &write, parser::CharBlock source);
-  void CheckAtomicUpdateAssignment(
+  std::optional<evaluate::Assignment> CheckAtomicUpdateAssignment(
       const evaluate::Assignment &update, parser::CharBlock source);
+  std::pair<bool, bool> CheckAtomicUpdateAssignmentRhs(const SomeExpr &atom,
+      const SomeExpr &rhs, parser::CharBlock source, bool suppressDiagnostics);
   void CheckAtomicConditionalUpdateAssignment(const SomeExpr &cond,
       parser::CharBlock condSource, const evaluate::Assignment &assign,
       parser::CharBlock assignSource);
diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
new file mode 100644
index 0000000000000..96ebb56b8d6ca
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
@@ -0,0 +1,75 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+subroutine f00(x, y)
+  implicit none
+  integer :: x, y
+
+  !$omp atomic update
+  x = ((x + 1) + y) + 2
+end
+
+!CHECK-LABEL: func.func @_QPf00
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %c1_i32, %[[LOAD_Y]] : i32
+!CHECK: %c2_i32 = arith.constant 2 : i32
+!CHECK: %[[Y_1_2:[0-9]+]] = arith.addi %[[Y_1]], %c2_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: i32):
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG]], %[[Y_1_2]] : i32
+!CHECK:   omp.yield(%[[ARG_P]] : i32)
+!CHECK: }
+
+
+subroutine f01(x, y)
+  implicit none
+  real :: x
+  integer :: y
+
+  !$omp atomic update
+  x = (int(x) + y) + 1
+end
+
+!CHECK-LABEL: func.func @_QPf01
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %[[LOAD_Y]], %c1_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<f32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: f32):
+!CHECK:   %[[ARG_I:[0-9]+]] = fir.convert %[[ARG]] : (f32) -> i32
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG_I]], %[[Y_1]] : i32
+!CHECK:   %[[ARG_F:[0-9]+]] = fir.convert %[[ARG_P]] : (i32) -> f32
+!CHECK:   omp.yield(%[[ARG_F]] : f32)
+!CHECK: }
+
+
+subroutine f02(x, a, b, c)
+  implicit none
+  integer(kind=4) :: x
+  integer(kind=8) :: a, b, c
+
+  !$omp atomic update
+  x = ((b + a) + x) + c
+end
+
+!CHECK-LABEL: func.func @_QPf02
+!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<i64>
+!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<i64>
+!CHECK: %[[A_B:[0-9]+]] = arith.addi %[[LOAD_B]], ...
[truncated]
 | 
    
| 
          
 @llvm/pr-subscribers-flang-fir-hlfir Author: Krzysztof Parzyszek (kparzysz) ChangesAn atomic update expression of form This PR implements this kind of reassociation for integral types, and for the two arithmetic associative/commutative operators: + and *. Reinstate PR153098 one more time with fixes for the issues that came up: 
 Patch is 22.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153488.diff 5 Files Affected: 
 diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp
index 0c0e6158485e9..9d92be6327fdb 100644
--- a/flang/lib/Semantics/check-omp-atomic.cpp
+++ b/flang/lib/Semantics/check-omp-atomic.cpp
@@ -13,7 +13,9 @@
 #include "check-omp-structure.h"
 
 #include "flang/Common/indirection.h"
+#include "flang/Common/template.h"
 #include "flang/Evaluate/expression.h"
+#include "flang/Evaluate/match.h"
 #include "flang/Evaluate/rewrite.h"
 #include "flang/Evaluate/tools.h"
 #include "flang/Parser/char-block.h"
@@ -50,6 +52,138 @@ static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) {
   return !(e == f);
 }
 
+namespace {
+template <typename...> struct IsIntegral {
+  static constexpr bool value{false};
+};
+
+template <common::TypeCategory C, int K>
+struct IsIntegral<evaluate::Type<C, K>> {
+  static constexpr bool value{//
+      C == common::TypeCategory::Integer ||
+      C == common::TypeCategory::Unsigned ||
+      C == common::TypeCategory::Logical};
+};
+
+template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
+
+template <typename T, typename Op0, typename Op1>
+using ReassocOpBase = evaluate::match::AnyOfPattern< //
+    evaluate::match::Add<T, Op0, Op1>, //
+    evaluate::match::Mul<T, Op0, Op1>>;
+
+template <typename T, typename Op0, typename Op1>
+struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
+  using Base = ReassocOpBase<T, Op0, Op1>;
+  using Base::Base;
+};
+
+template <typename T, typename Op0, typename Op1>
+ReassocOp<T, Op0, Op1> reassocOp(const Op0 &op0, const Op1 &op1) {
+  return ReassocOp<T, Op0, Op1>(op0, op1);
+}
+} // namespace
+
+struct ReassocRewriter : public evaluate::rewrite::Identity {
+  using Id = evaluate::rewrite::Identity;
+  using Id::operator();
+  struct NonIntegralTag {};
+
+  ReassocRewriter(const SomeExpr &atom) : atom_(atom) {}
+
+  // Try to find cases where the input expression is of the form
+  // (1) (a . b) . c, or
+  // (2) a . (b . c),
+  // where . denotes an associative operation (currently + or *), and a, b, c
+  // are some subexpresions.
+  // If one of the operands in the nested operation is the atomic variable
+  // (with some possible type conversions applied to it), bring it to the
+  // top-level operation, and move the top-level operand into the nested
+  // operation.
+  // For example, assuming x is the atomic variable:
+  //   (a + x) + b  ->  (a + b) + x,  i.e. (conceptually) swap x and b.
+  template <typename T, typename U,
+      typename = std::enable_if_t<is_integral_v<T>>>
+  evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
+    // As per the above comment, there are 3 subexpressions involved in this
+    // transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
+    // same as U, plus it will store a pointer (ref) to the matched expression.
+    // When the match is successful, the sub[i].ref will point to a, b, x (in
+    // some order) from the example above.
+    evaluate::match::Expr<T> sub[3];
+    auto inner{reassocOp<T>(sub[0], sub[1])};
+    auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
+    auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
+#if !defined(__clang__) && !defined(_MSC_VER) && \
+      (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
+    // If GCC version < 8.5, use this definition. For the other definition
+    // (which is equivalent), GCC 7.5 emits a somewhat cryptic error:
+    //    use of ‘outer1’ before deduction of ‘auto’
+    // inside of the visitor function in common::visit.
+    // Since this works with clang, MSVC and at least GCC 8.5, I'm assuming
+    // that this is some kind of a GCC issue.
+    using MatchTypes = std::tuple<evaluate::Add<T>, evaluate::Multiply<T>>;
+#else
+    using MatchTypes = typename decltype(outer1)::MatchTypes;
+#endif
+    // There is no way to ensure that the outer operation is the same as
+    // the inner one. They are matched independently, so we need to compare
+    // the index in the member variant that represents the matched type.
+    if ((match(outer1, x) && outer1.ref.index() == inner.ref.index()) ||
+        (match(outer2, x) && outer2.ref.index() == inner.ref.index())) {
+      size_t atomIdx{[&]() { // sub[atomIdx] will be the atom.
+        size_t idx;
+        for (idx = 0; idx != 3; ++idx) {
+          if (IsAtom(*sub[idx].ref)) {
+            break;
+          }
+        }
+        return idx;
+      }()};
+
+      if (atomIdx > 2) {
+        return Id::operator()(std::move(x), u);
+      }
+      return common::visit(
+          [&](auto &&s) {
+            using Expr = evaluate::Expr<T>;
+            using TypeS = llvm::remove_cvref_t<decltype(s)>;
+            // This visitor has to be semantically correct for all possible
+            // types of s even though at runtime s will only be one of the
+            // matched types.
+            // Limit the construction to the operation types that we tried
+            // to match (otherwise TypeS(op1, op2) would fail for non-binary
+            // operations).
+            if constexpr (common::HasMember<TypeS, MatchTypes>) {
+              Expr atom{*sub[atomIdx].ref};
+              Expr op1{*sub[(atomIdx + 1) % 3].ref};
+              Expr op2{*sub[(atomIdx + 2) % 3].ref};
+              return Expr(
+                  TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
+            } else {
+              return Expr(TypeS(s));
+            }
+          },
+          evaluate::match::deparen(x).u);
+    }
+    return Id::operator()(std::move(x), u);
+  }
+
+  template <typename T, typename U,
+      typename = std::enable_if_t<!is_integral_v<T>>>
+  evaluate::Expr<T> operator()(
+      evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
+    return Id::operator()(std::move(x), u);
+  }
+
+private:
+  template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
+    return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
+  }
+
+  const SomeExpr &atom_;
+};
+
 struct AnalyzedCondStmt {
   SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
   parser::CharBlock source;
@@ -199,6 +333,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
   llvm_unreachable("Could not find assignment operator");
 }
 
+static std::vector<SomeExpr> GetNonAtomExpressions(
+    const SomeExpr &atom, const std::vector<SomeExpr> &exprs) {
+  std::vector<SomeExpr> nonAtom;
+  for (const SomeExpr &e : exprs) {
+    if (!IsSameOrConvertOf(e, atom)) {
+      nonAtom.push_back(e);
+    }
+  }
+  return nonAtom;
+}
+
+static std::vector<SomeExpr> GetNonAtomArguments(
+    const SomeExpr &atom, const SomeExpr &expr) {
+  if (auto &&maybe{GetConvertInput(expr)}) {
+    return GetNonAtomExpressions(
+        atom, GetTopLevelOperationIgnoreResizing(*maybe).second);
+  }
+  return {};
+}
+
 static bool IsCheckForAssociated(const SomeExpr &cond) {
   return GetTopLevelOperationIgnoreResizing(cond).first ==
       operation::Operator::Associated;
@@ -576,6 +730,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
     const evaluate::Assignment &capture, const SomeExpr &atom,
     parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
   const SomeExpr &cap{capture.lhs};
 
   if (!IsVarOrFunctionRef(atom)) {
@@ -592,6 +747,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
 void OmpStructureChecker::CheckAtomicReadAssignment(
     const evaluate::Assignment &read, parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
 
   if (auto maybe{GetConvertInput(read.rhs)}) {
     const SomeExpr &atom{*maybe};
@@ -625,7 +781,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment(
   }
 }
 
-void OmpStructureChecker::CheckAtomicUpdateAssignment(
+std::optional<evaluate::Assignment>
+OmpStructureChecker::CheckAtomicUpdateAssignment(
     const evaluate::Assignment &update, parser::CharBlock source) {
   // [6.0:191:1-7]
   // An update structured block is update-statement, an update statement
@@ -641,14 +798,47 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   if (!IsVarOrFunctionRef(atom)) {
     ErrorShouldBeVariable(atom, rsrc);
     // Skip other checks.
-    return;
+    return std::nullopt;
   }
 
   CheckAtomicVariable(atom, lsrc);
 
+  auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/true)};
+
+  if (!hasErrors) {
+    CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source);
+    return std::nullopt;
+  } else if (tryReassoc) {
+    ReassocRewriter ra(atom);
+    SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)};
+
+    std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs(
+        atom, raRhs, source, /*suppressDiagnostics=*/true);
+    if (!hasErrors) {
+      CheckStorageOverlap(atom, GetNonAtomArguments(atom, raRhs), source);
+
+      evaluate::Assignment raAssign(update);
+      raAssign.rhs = raRhs;
+      return raAssign;
+    }
+  }
+
+  // This is guaranteed to report errors.
+  CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/false);
+  return std::nullopt;
+}
+
+std::pair<bool, bool> OmpStructureChecker::CheckAtomicUpdateAssignmentRhs(
+    const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source,
+    bool suppressDiagnostics) {
+  auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
+
   std::pair<operation::Operator, std::vector<SomeExpr>> top{
       operation::Operator::Unknown, {}};
-  if (auto &&maybeInput{GetConvertInput(update.rhs)}) {
+  if (auto &&maybeInput{GetConvertInput(rhs)}) {
     top = GetTopLevelOperationIgnoreResizing(*maybeInput);
   }
   switch (top.first) {
@@ -665,29 +855,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   case operation::Operator::Identity:
     break;
   case operation::Operator::Call:
-    context_.Say(source,
-        "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Convert:
-    context_.Say(source,
-        "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Intrinsic:
-    context_.Say(source,
-        "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Constant:
   case operation::Operator::Unknown:
-    context_.Say(
-        source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(
+          source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   default:
     assert(
         top.first != operation::Operator::Identity && "Handle this separately");
-    context_.Say(source,
-        "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
-        operation::ToString(top.first));
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
+          operation::ToString(top.first));
+    }
+    return std::make_pair(true, false);
   }
   // Check how many times `atom` occurs as an argument, if it's a subexpression
   // of an argument, and collect the non-atom arguments.
@@ -708,39 +908,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
     return count;
   }()};
 
-  bool hasError{false};
+  bool hasError{false}, tryReassoc{false};
   if (subExpr) {
-    context_.Say(rsrc,
-        "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
-        atom.AsFortran(), subExpr->AsFortran());
+    if (!suppressDiagnostics) {
+      context_.Say(rsrc,
+          "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
+          atom.AsFortran(), subExpr->AsFortran());
+    }
     hasError = true;
   }
   if (top.first == operation::Operator::Identity) {
     // This is "x = y".
     assert((atomCount == 0 || atomCount == 1) && "Unexpected count");
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
-          atom.AsFortran());
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
+            atom.AsFortran());
+      }
       hasError = true;
     }
   } else {
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
+      // If `atom` is a proper subexpression, and it not present as an
+      // argument on its own, reassociation may be able to help.
+      tryReassoc = subExpr.has_value();
       hasError = true;
     } else if (atomCount > 1) {
-      context_.Say(rsrc,
-          "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
       hasError = true;
     }
   }
 
-  if (!hasError) {
-    CheckStorageOverlap(atom, nonAtom, source);
-  }
+  return std::make_pair(hasError, tryReassoc);
 }
 
 void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment(
@@ -843,11 +1052,13 @@ void OmpStructureChecker::CheckAtomicUpdateOnly(
     SourcedActionStmt action{GetActionStmt(&body.front())};
     if (auto maybeUpdate{GetEvaluateAssignment(action.stmt)}) {
       const SomeExpr &atom{maybeUpdate->lhs};
-      CheckAtomicUpdateAssignment(*maybeUpdate, action.source);
+      auto maybeAssign{
+          CheckAtomicUpdateAssignment(*maybeUpdate, action.source)};
+      auto &updateAssign{maybeAssign.has_value() ? maybeAssign : maybeUpdate};
 
       using Analysis = parser::OpenMPAtomicConstruct::Analysis;
       x.analysis = AtomicAnalysis(atom)
-                       .addOp0(Analysis::Update, maybeUpdate)
+                       .addOp0(Analysis::Update, updateAssign)
                        .addOp1(Analysis::None);
     } else if (!IsAssignment(action.stmt)) {
       context_.Say(
@@ -963,16 +1174,19 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
   using Analysis = parser::OpenMPAtomicConstruct::Analysis;
   int action;
 
+  std::optional<evaluate::Assignment> updateAssign{update};
   if (IsMaybeAtomicWrite(update)) {
     action = Analysis::Write;
     CheckAtomicWriteAssignment(update, uact.source);
   } else {
     action = Analysis::Update;
-    CheckAtomicUpdateAssignment(update, uact.source);
+    if (auto &&maybe{CheckAtomicUpdateAssignment(update, uact.source)}) {
+      updateAssign = maybe;
+    }
   }
   CheckAtomicCaptureAssignment(capture, atom, cact.source);
 
-  if (IsPointerAssignment(update) != IsPointerAssignment(capture)) {
+  if (IsPointerAssignment(*updateAssign) != IsPointerAssignment(capture)) {
     context_.Say(cact.source,
         "The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments"_err_en_US);
     return;
@@ -980,12 +1194,12 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
 
   if (GetActionStmt(&body.front()).stmt == uact.stmt) {
     x.analysis = AtomicAnalysis(atom)
-                     .addOp0(action, update)
+                     .addOp0(action, updateAssign)
                      .addOp1(Analysis::Read, capture);
   } else {
     x.analysis = AtomicAnalysis(atom)
                      .addOp0(Analysis::Read, capture)
-                     .addOp1(action, update);
+                     .addOp1(action, updateAssign);
   }
 }
 
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 6b33ca6ab583f..a973aee28d0e2 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -267,8 +267,10 @@ class OmpStructureChecker
       const evaluate::Assignment &read, parser::CharBlock source);
   void CheckAtomicWriteAssignment(
       const evaluate::Assignment &write, parser::CharBlock source);
-  void CheckAtomicUpdateAssignment(
+  std::optional<evaluate::Assignment> CheckAtomicUpdateAssignment(
       const evaluate::Assignment &update, parser::CharBlock source);
+  std::pair<bool, bool> CheckAtomicUpdateAssignmentRhs(const SomeExpr &atom,
+      const SomeExpr &rhs, parser::CharBlock source, bool suppressDiagnostics);
   void CheckAtomicConditionalUpdateAssignment(const SomeExpr &cond,
       parser::CharBlock condSource, const evaluate::Assignment &assign,
       parser::CharBlock assignSource);
diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
new file mode 100644
index 0000000000000..96ebb56b8d6ca
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
@@ -0,0 +1,75 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+subroutine f00(x, y)
+  implicit none
+  integer :: x, y
+
+  !$omp atomic update
+  x = ((x + 1) + y) + 2
+end
+
+!CHECK-LABEL: func.func @_QPf00
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %c1_i32, %[[LOAD_Y]] : i32
+!CHECK: %c2_i32 = arith.constant 2 : i32
+!CHECK: %[[Y_1_2:[0-9]+]] = arith.addi %[[Y_1]], %c2_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: i32):
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG]], %[[Y_1_2]] : i32
+!CHECK:   omp.yield(%[[ARG_P]] : i32)
+!CHECK: }
+
+
+subroutine f01(x, y)
+  implicit none
+  real :: x
+  integer :: y
+
+  !$omp atomic update
+  x = (int(x) + y) + 1
+end
+
+!CHECK-LABEL: func.func @_QPf01
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %[[LOAD_Y]], %c1_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<f32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: f32):
+!CHECK:   %[[ARG_I:[0-9]+]] = fir.convert %[[ARG]] : (f32) -> i32
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG_I]], %[[Y_1]] : i32
+!CHECK:   %[[ARG_F:[0-9]+]] = fir.convert %[[ARG_P]] : (i32) -> f32
+!CHECK:   omp.yield(%[[ARG_F]] : f32)
+!CHECK: }
+
+
+subroutine f02(x, a, b, c)
+  implicit none
+  integer(kind=4) :: x
+  integer(kind=8) :: a, b, c
+
+  !$omp atomic update
+  x = ((b + a) + x) + c
+end
+
+!CHECK-LABEL: func.func @_QPf02
+!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<i64>
+!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<i64>
+!CHECK: %[[A_B:[0-9]+]] = arith.addi %[[LOAD_B]], ...
[truncated]
 | 
    
| 
          
 ✅ With the latest revision this PR passed the C/C++ code formatter.  | 
    
PR #153488 caused the msvc build (https://lab.llvm.org/buildbot/#/builders/166/builds/1397) to fail: ``` ..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): error C2668: 'Fortran::evaluate::rewrite::Identity::operator ()': ambiguous call to overloaded function ..\llvm-project\flang\include\flang/Evaluate/rewrite.h(43): note: could be 'Fortran::evaluate::Expr<Fortran::evaluate::SomeType> Fortran::evaluate::rewrite::Identity::operator ()<Fortran::evaluate::SomeType,S>(Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &&,const U &)' with [ S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>, U=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128> ] ..\llvm-project\flang\lib\Semantics\check-omp-atomic.cpp(174): note: or 'Fortran::evaluate::Expr<Fortran::evaluate::SomeType> Fortran::semantics::ReassocRewriter::operator ()<Fortran::evaluate::SomeType,S,void>(Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &&,const U &,Fortran::semantics::ReassocRewriter::NonIntegralTag)' with [ S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>, U=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128> ] ..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): note: while trying to match the argument list '(Fortran::evaluate::Expr<Fortran::evaluate::SomeType>, const S)' with [ S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128> ] ..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): note: the template instantiation context (the oldest one first) is ..\llvm-project\flang\lib\Semantics\check-omp-atomic.cpp(814): note: see reference to function template instantiation 'U Fortran::evaluate::rewrite::Mutator<Fortran::semantics::ReassocRewriter>::operator ()<const Fortran::evaluate::Expr<Fortran::evaluate::SomeType>&,Fortran::evaluate::Expr<Fortran::evaluate::SomeType>>(T)' being compiled with [ U=Fortran::evaluate::Expr<Fortran::evaluate::SomeType>, T=const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> & ] ``` The reason is that there is an ambiguity between operator() of ReassocRewriter itself and operator() of the base class `Identity` through `using Id::operator();`. By the C++ specification, method declarations in ReassocRewriter hide methods with the same signature from a using declaration, but this does not apply to ``` evaluate::Expr<T> operator()(..., NonIntegralTag = {}) ``` which has a different signature due to an additional tag parameter. Since it has a default value, it is ambiguous with operator() without tag parameter. GCC and Clang both accept this, but in my understanding MSVC is correct here. Since the overloads of ReassocRewriter cover all cases (integral and non-integral), removing the using declaration to avoid the ambiguity.
| 
           I fixed a compilation break in 38853a0. I would recommend using   | 
    
PR llvm#153488 caused the msvc build (https://lab.llvm.org/buildbot/#/builders/166/builds/1397) to fail: ``` ..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): error C2668: 'Fortran::evaluate::rewrite::Identity::operator ()': ambiguous call to overloaded function ..\llvm-project\flang\include\flang/Evaluate/rewrite.h(43): note: could be 'Fortran::evaluate::Expr<Fortran::evaluate::SomeType> Fortran::evaluate::rewrite::Identity::operator ()<Fortran::evaluate::SomeType,S>(Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &&,const U &)' with [ S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>, U=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128> ] ..\llvm-project\flang\lib\Semantics\check-omp-atomic.cpp(174): note: or 'Fortran::evaluate::Expr<Fortran::evaluate::SomeType> Fortran::semantics::ReassocRewriter::operator ()<Fortran::evaluate::SomeType,S,void>(Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &&,const U &,Fortran::semantics::ReassocRewriter::NonIntegralTag)' with [ S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>, U=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128> ] ..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): note: while trying to match the argument list '(Fortran::evaluate::Expr<Fortran::evaluate::SomeType>, const S)' with [ S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128> ] ..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): note: the template instantiation context (the oldest one first) is ..\llvm-project\flang\lib\Semantics\check-omp-atomic.cpp(814): note: see reference to function template instantiation 'U Fortran::evaluate::rewrite::Mutator<Fortran::semantics::ReassocRewriter>::operator ()<const Fortran::evaluate::Expr<Fortran::evaluate::SomeType>&,Fortran::evaluate::Expr<Fortran::evaluate::SomeType>>(T)' being compiled with [ U=Fortran::evaluate::Expr<Fortran::evaluate::SomeType>, T=const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> & ] ``` The reason is that there is an ambiguite between operator() of ReassocRewriter itself as operator() of the base class Identity through `using Id::operator();`. By the C++ specification, method declarations in ReassocRewriter hide methods with the same signature from a using declaration, but this does not apply to ``` evaluate::Expr<T> operator()(..., NonIntegralTag = {}) ``` which has a different signature due to an additional tag parameter. Since it has a default value, it is ambiguous with operator() without tag parameter. GCC and Clang both accept this, but in my understanding MSVC is correct here. Since the overloads of ReassocRewriter cover all cases, remopving the using reclarations to avoid the ambiguity.
An atomic update expression of form
x = x + a + b
is technically illegal, since the right-hand side is parsed as (x+a)+b, and the atomic variable x should be an argument to the top-level +. When the type of x is integer, the result of (x+a)+b is guaranteed to be the same as x+(a+b), so instead of reporting an error, the compiler can treat (x+a)+b as x+(a+b).
This PR implements this kind of reassociation for integral types, and for the two arithmetic associative/commutative operators: + and *.
Reinstate PR153098 one more time with fixes for the issues that came up: