Skip to content

Conversation

@kparzysz
Copy link
Contributor

An atomic update expression of form
x = x + a + b
is technically illegal, since the right-hand side is parsed as (x+a)+b, and the atomic variable x should be an argument to the top-level +. When the type of x is integer, the result of (x+a)+b is guaranteed to be the same as x+(a+b), so instead of reporting an error, the compiler can treat (x+a)+b as x+(a+b).

This PR implements this kind of reassociation for integral types, and for the two arithmetic associative/commutative operators: + and *.

Reinstate PR153098 one more time with fixes for the issues that came up:

  • unused variable "lsrc",
  • use of ‘outer1’ before deduction of ‘auto’.

An atomic update expression of form
  x = x + a + b
is technically illegal, since the right-hand side is parsed as (x+a)+b,
and the atomic variable x should be an argument to the top-level +. When
the type of x is integer, the result of (x+a)+b is guaranteed to be the
same as x+(a+b), so instead of reporting an error, the compiler can
treat (x+a)+b as x+(a+b).

This PR implements this kind of reassociation for integral types, and
for the two arithmetic associative/commutative operators: + and *.

Reinstate PR153098 one more time with fixes for the issues that came up:
- unused variable "lsrc",
- use of ‘outer1’ before deduction of ‘auto’.
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp flang:semantics labels Aug 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 13, 2025

@llvm/pr-subscribers-flang-semantics

Author: Krzysztof Parzyszek (kparzysz)

Changes

An atomic update expression of form
x = x + a + b
is technically illegal, since the right-hand side is parsed as (x+a)+b, and the atomic variable x should be an argument to the top-level +. When the type of x is integer, the result of (x+a)+b is guaranteed to be the same as x+(a+b), so instead of reporting an error, the compiler can treat (x+a)+b as x+(a+b).

This PR implements this kind of reassociation for integral types, and for the two arithmetic associative/commutative operators: + and *.

Reinstate PR153098 one more time with fixes for the issues that came up:

  • unused variable "lsrc",
  • use of ‘outer1’ before deduction of ‘auto’.

Patch is 22.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153488.diff

5 Files Affected:

  • (modified) flang/lib/Semantics/check-omp-atomic.cpp (+255-41)
  • (modified) flang/lib/Semantics/check-omp-structure.h (+3-1)
  • (added) flang/test/Lower/OpenMP/atomic-update-reassoc.f90 (+75)
  • (modified) flang/test/Semantics/OpenMP/atomic-update-only.f90 (+9-2)
  • (modified) flang/test/Semantics/OpenMP/atomic04.f90 (+1-2)
diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp
index 0c0e6158485e9..9d92be6327fdb 100644
--- a/flang/lib/Semantics/check-omp-atomic.cpp
+++ b/flang/lib/Semantics/check-omp-atomic.cpp
@@ -13,7 +13,9 @@
 #include "check-omp-structure.h"
 
 #include "flang/Common/indirection.h"
+#include "flang/Common/template.h"
 #include "flang/Evaluate/expression.h"
+#include "flang/Evaluate/match.h"
 #include "flang/Evaluate/rewrite.h"
 #include "flang/Evaluate/tools.h"
 #include "flang/Parser/char-block.h"
@@ -50,6 +52,138 @@ static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) {
   return !(e == f);
 }
 
+namespace {
+template <typename...> struct IsIntegral {
+  static constexpr bool value{false};
+};
+
+template <common::TypeCategory C, int K>
+struct IsIntegral<evaluate::Type<C, K>> {
+  static constexpr bool value{//
+      C == common::TypeCategory::Integer ||
+      C == common::TypeCategory::Unsigned ||
+      C == common::TypeCategory::Logical};
+};
+
+template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
+
+template <typename T, typename Op0, typename Op1>
+using ReassocOpBase = evaluate::match::AnyOfPattern< //
+    evaluate::match::Add<T, Op0, Op1>, //
+    evaluate::match::Mul<T, Op0, Op1>>;
+
+template <typename T, typename Op0, typename Op1>
+struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
+  using Base = ReassocOpBase<T, Op0, Op1>;
+  using Base::Base;
+};
+
+template <typename T, typename Op0, typename Op1>
+ReassocOp<T, Op0, Op1> reassocOp(const Op0 &op0, const Op1 &op1) {
+  return ReassocOp<T, Op0, Op1>(op0, op1);
+}
+} // namespace
+
+struct ReassocRewriter : public evaluate::rewrite::Identity {
+  using Id = evaluate::rewrite::Identity;
+  using Id::operator();
+  struct NonIntegralTag {};
+
+  ReassocRewriter(const SomeExpr &atom) : atom_(atom) {}
+
+  // Try to find cases where the input expression is of the form
+  // (1) (a . b) . c, or
+  // (2) a . (b . c),
+  // where . denotes an associative operation (currently + or *), and a, b, c
+  // are some subexpresions.
+  // If one of the operands in the nested operation is the atomic variable
+  // (with some possible type conversions applied to it), bring it to the
+  // top-level operation, and move the top-level operand into the nested
+  // operation.
+  // For example, assuming x is the atomic variable:
+  //   (a + x) + b  ->  (a + b) + x,  i.e. (conceptually) swap x and b.
+  template <typename T, typename U,
+      typename = std::enable_if_t<is_integral_v<T>>>
+  evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
+    // As per the above comment, there are 3 subexpressions involved in this
+    // transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
+    // same as U, plus it will store a pointer (ref) to the matched expression.
+    // When the match is successful, the sub[i].ref will point to a, b, x (in
+    // some order) from the example above.
+    evaluate::match::Expr<T> sub[3];
+    auto inner{reassocOp<T>(sub[0], sub[1])};
+    auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
+    auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
+#if !defined(__clang__) && !defined(_MSC_VER) && \
+      (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
+    // If GCC version < 8.5, use this definition. For the other definition
+    // (which is equivalent), GCC 7.5 emits a somewhat cryptic error:
+    //    use of ‘outer1’ before deduction of ‘auto’
+    // inside of the visitor function in common::visit.
+    // Since this works with clang, MSVC and at least GCC 8.5, I'm assuming
+    // that this is some kind of a GCC issue.
+    using MatchTypes = std::tuple<evaluate::Add<T>, evaluate::Multiply<T>>;
+#else
+    using MatchTypes = typename decltype(outer1)::MatchTypes;
+#endif
+    // There is no way to ensure that the outer operation is the same as
+    // the inner one. They are matched independently, so we need to compare
+    // the index in the member variant that represents the matched type.
+    if ((match(outer1, x) && outer1.ref.index() == inner.ref.index()) ||
+        (match(outer2, x) && outer2.ref.index() == inner.ref.index())) {
+      size_t atomIdx{[&]() { // sub[atomIdx] will be the atom.
+        size_t idx;
+        for (idx = 0; idx != 3; ++idx) {
+          if (IsAtom(*sub[idx].ref)) {
+            break;
+          }
+        }
+        return idx;
+      }()};
+
+      if (atomIdx > 2) {
+        return Id::operator()(std::move(x), u);
+      }
+      return common::visit(
+          [&](auto &&s) {
+            using Expr = evaluate::Expr<T>;
+            using TypeS = llvm::remove_cvref_t<decltype(s)>;
+            // This visitor has to be semantically correct for all possible
+            // types of s even though at runtime s will only be one of the
+            // matched types.
+            // Limit the construction to the operation types that we tried
+            // to match (otherwise TypeS(op1, op2) would fail for non-binary
+            // operations).
+            if constexpr (common::HasMember<TypeS, MatchTypes>) {
+              Expr atom{*sub[atomIdx].ref};
+              Expr op1{*sub[(atomIdx + 1) % 3].ref};
+              Expr op2{*sub[(atomIdx + 2) % 3].ref};
+              return Expr(
+                  TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
+            } else {
+              return Expr(TypeS(s));
+            }
+          },
+          evaluate::match::deparen(x).u);
+    }
+    return Id::operator()(std::move(x), u);
+  }
+
+  template <typename T, typename U,
+      typename = std::enable_if_t<!is_integral_v<T>>>
+  evaluate::Expr<T> operator()(
+      evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
+    return Id::operator()(std::move(x), u);
+  }
+
+private:
+  template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
+    return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
+  }
+
+  const SomeExpr &atom_;
+};
+
 struct AnalyzedCondStmt {
   SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
   parser::CharBlock source;
@@ -199,6 +333,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
   llvm_unreachable("Could not find assignment operator");
 }
 
+static std::vector<SomeExpr> GetNonAtomExpressions(
+    const SomeExpr &atom, const std::vector<SomeExpr> &exprs) {
+  std::vector<SomeExpr> nonAtom;
+  for (const SomeExpr &e : exprs) {
+    if (!IsSameOrConvertOf(e, atom)) {
+      nonAtom.push_back(e);
+    }
+  }
+  return nonAtom;
+}
+
+static std::vector<SomeExpr> GetNonAtomArguments(
+    const SomeExpr &atom, const SomeExpr &expr) {
+  if (auto &&maybe{GetConvertInput(expr)}) {
+    return GetNonAtomExpressions(
+        atom, GetTopLevelOperationIgnoreResizing(*maybe).second);
+  }
+  return {};
+}
+
 static bool IsCheckForAssociated(const SomeExpr &cond) {
   return GetTopLevelOperationIgnoreResizing(cond).first ==
       operation::Operator::Associated;
@@ -576,6 +730,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
     const evaluate::Assignment &capture, const SomeExpr &atom,
     parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
   const SomeExpr &cap{capture.lhs};
 
   if (!IsVarOrFunctionRef(atom)) {
@@ -592,6 +747,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
 void OmpStructureChecker::CheckAtomicReadAssignment(
     const evaluate::Assignment &read, parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
 
   if (auto maybe{GetConvertInput(read.rhs)}) {
     const SomeExpr &atom{*maybe};
@@ -625,7 +781,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment(
   }
 }
 
-void OmpStructureChecker::CheckAtomicUpdateAssignment(
+std::optional<evaluate::Assignment>
+OmpStructureChecker::CheckAtomicUpdateAssignment(
     const evaluate::Assignment &update, parser::CharBlock source) {
   // [6.0:191:1-7]
   // An update structured block is update-statement, an update statement
@@ -641,14 +798,47 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   if (!IsVarOrFunctionRef(atom)) {
     ErrorShouldBeVariable(atom, rsrc);
     // Skip other checks.
-    return;
+    return std::nullopt;
   }
 
   CheckAtomicVariable(atom, lsrc);
 
+  auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/true)};
+
+  if (!hasErrors) {
+    CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source);
+    return std::nullopt;
+  } else if (tryReassoc) {
+    ReassocRewriter ra(atom);
+    SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)};
+
+    std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs(
+        atom, raRhs, source, /*suppressDiagnostics=*/true);
+    if (!hasErrors) {
+      CheckStorageOverlap(atom, GetNonAtomArguments(atom, raRhs), source);
+
+      evaluate::Assignment raAssign(update);
+      raAssign.rhs = raRhs;
+      return raAssign;
+    }
+  }
+
+  // This is guaranteed to report errors.
+  CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/false);
+  return std::nullopt;
+}
+
+std::pair<bool, bool> OmpStructureChecker::CheckAtomicUpdateAssignmentRhs(
+    const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source,
+    bool suppressDiagnostics) {
+  auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
+
   std::pair<operation::Operator, std::vector<SomeExpr>> top{
       operation::Operator::Unknown, {}};
-  if (auto &&maybeInput{GetConvertInput(update.rhs)}) {
+  if (auto &&maybeInput{GetConvertInput(rhs)}) {
     top = GetTopLevelOperationIgnoreResizing(*maybeInput);
   }
   switch (top.first) {
@@ -665,29 +855,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   case operation::Operator::Identity:
     break;
   case operation::Operator::Call:
-    context_.Say(source,
-        "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Convert:
-    context_.Say(source,
-        "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Intrinsic:
-    context_.Say(source,
-        "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Constant:
   case operation::Operator::Unknown:
-    context_.Say(
-        source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(
+          source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   default:
     assert(
         top.first != operation::Operator::Identity && "Handle this separately");
-    context_.Say(source,
-        "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
-        operation::ToString(top.first));
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
+          operation::ToString(top.first));
+    }
+    return std::make_pair(true, false);
   }
   // Check how many times `atom` occurs as an argument, if it's a subexpression
   // of an argument, and collect the non-atom arguments.
@@ -708,39 +908,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
     return count;
   }()};
 
-  bool hasError{false};
+  bool hasError{false}, tryReassoc{false};
   if (subExpr) {
-    context_.Say(rsrc,
-        "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
-        atom.AsFortran(), subExpr->AsFortran());
+    if (!suppressDiagnostics) {
+      context_.Say(rsrc,
+          "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
+          atom.AsFortran(), subExpr->AsFortran());
+    }
     hasError = true;
   }
   if (top.first == operation::Operator::Identity) {
     // This is "x = y".
     assert((atomCount == 0 || atomCount == 1) && "Unexpected count");
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
-          atom.AsFortran());
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
+            atom.AsFortran());
+      }
       hasError = true;
     }
   } else {
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
+      // If `atom` is a proper subexpression, and it not present as an
+      // argument on its own, reassociation may be able to help.
+      tryReassoc = subExpr.has_value();
       hasError = true;
     } else if (atomCount > 1) {
-      context_.Say(rsrc,
-          "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
       hasError = true;
     }
   }
 
-  if (!hasError) {
-    CheckStorageOverlap(atom, nonAtom, source);
-  }
+  return std::make_pair(hasError, tryReassoc);
 }
 
 void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment(
@@ -843,11 +1052,13 @@ void OmpStructureChecker::CheckAtomicUpdateOnly(
     SourcedActionStmt action{GetActionStmt(&body.front())};
     if (auto maybeUpdate{GetEvaluateAssignment(action.stmt)}) {
       const SomeExpr &atom{maybeUpdate->lhs};
-      CheckAtomicUpdateAssignment(*maybeUpdate, action.source);
+      auto maybeAssign{
+          CheckAtomicUpdateAssignment(*maybeUpdate, action.source)};
+      auto &updateAssign{maybeAssign.has_value() ? maybeAssign : maybeUpdate};
 
       using Analysis = parser::OpenMPAtomicConstruct::Analysis;
       x.analysis = AtomicAnalysis(atom)
-                       .addOp0(Analysis::Update, maybeUpdate)
+                       .addOp0(Analysis::Update, updateAssign)
                        .addOp1(Analysis::None);
     } else if (!IsAssignment(action.stmt)) {
       context_.Say(
@@ -963,16 +1174,19 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
   using Analysis = parser::OpenMPAtomicConstruct::Analysis;
   int action;
 
+  std::optional<evaluate::Assignment> updateAssign{update};
   if (IsMaybeAtomicWrite(update)) {
     action = Analysis::Write;
     CheckAtomicWriteAssignment(update, uact.source);
   } else {
     action = Analysis::Update;
-    CheckAtomicUpdateAssignment(update, uact.source);
+    if (auto &&maybe{CheckAtomicUpdateAssignment(update, uact.source)}) {
+      updateAssign = maybe;
+    }
   }
   CheckAtomicCaptureAssignment(capture, atom, cact.source);
 
-  if (IsPointerAssignment(update) != IsPointerAssignment(capture)) {
+  if (IsPointerAssignment(*updateAssign) != IsPointerAssignment(capture)) {
     context_.Say(cact.source,
         "The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments"_err_en_US);
     return;
@@ -980,12 +1194,12 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
 
   if (GetActionStmt(&body.front()).stmt == uact.stmt) {
     x.analysis = AtomicAnalysis(atom)
-                     .addOp0(action, update)
+                     .addOp0(action, updateAssign)
                      .addOp1(Analysis::Read, capture);
   } else {
     x.analysis = AtomicAnalysis(atom)
                      .addOp0(Analysis::Read, capture)
-                     .addOp1(action, update);
+                     .addOp1(action, updateAssign);
   }
 }
 
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 6b33ca6ab583f..a973aee28d0e2 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -267,8 +267,10 @@ class OmpStructureChecker
       const evaluate::Assignment &read, parser::CharBlock source);
   void CheckAtomicWriteAssignment(
       const evaluate::Assignment &write, parser::CharBlock source);
-  void CheckAtomicUpdateAssignment(
+  std::optional<evaluate::Assignment> CheckAtomicUpdateAssignment(
       const evaluate::Assignment &update, parser::CharBlock source);
+  std::pair<bool, bool> CheckAtomicUpdateAssignmentRhs(const SomeExpr &atom,
+      const SomeExpr &rhs, parser::CharBlock source, bool suppressDiagnostics);
   void CheckAtomicConditionalUpdateAssignment(const SomeExpr &cond,
       parser::CharBlock condSource, const evaluate::Assignment &assign,
       parser::CharBlock assignSource);
diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
new file mode 100644
index 0000000000000..96ebb56b8d6ca
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
@@ -0,0 +1,75 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+subroutine f00(x, y)
+  implicit none
+  integer :: x, y
+
+  !$omp atomic update
+  x = ((x + 1) + y) + 2
+end
+
+!CHECK-LABEL: func.func @_QPf00
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %c1_i32, %[[LOAD_Y]] : i32
+!CHECK: %c2_i32 = arith.constant 2 : i32
+!CHECK: %[[Y_1_2:[0-9]+]] = arith.addi %[[Y_1]], %c2_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: i32):
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG]], %[[Y_1_2]] : i32
+!CHECK:   omp.yield(%[[ARG_P]] : i32)
+!CHECK: }
+
+
+subroutine f01(x, y)
+  implicit none
+  real :: x
+  integer :: y
+
+  !$omp atomic update
+  x = (int(x) + y) + 1
+end
+
+!CHECK-LABEL: func.func @_QPf01
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %[[LOAD_Y]], %c1_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<f32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: f32):
+!CHECK:   %[[ARG_I:[0-9]+]] = fir.convert %[[ARG]] : (f32) -> i32
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG_I]], %[[Y_1]] : i32
+!CHECK:   %[[ARG_F:[0-9]+]] = fir.convert %[[ARG_P]] : (i32) -> f32
+!CHECK:   omp.yield(%[[ARG_F]] : f32)
+!CHECK: }
+
+
+subroutine f02(x, a, b, c)
+  implicit none
+  integer(kind=4) :: x
+  integer(kind=8) :: a, b, c
+
+  !$omp atomic update
+  x = ((b + a) + x) + c
+end
+
+!CHECK-LABEL: func.func @_QPf02
+!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<i64>
+!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<i64>
+!CHECK: %[[A_B:[0-9]+]] = arith.addi %[[LOAD_B]], ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 13, 2025

@llvm/pr-subscribers-flang-openmp

Author: Krzysztof Parzyszek (kparzysz)

Changes

An atomic update expression of form
x = x + a + b
is technically illegal, since the right-hand side is parsed as (x+a)+b, and the atomic variable x should be an argument to the top-level +. When the type of x is integer, the result of (x+a)+b is guaranteed to be the same as x+(a+b), so instead of reporting an error, the compiler can treat (x+a)+b as x+(a+b).

This PR implements this kind of reassociation for integral types, and for the two arithmetic associative/commutative operators: + and *.

Reinstate PR153098 one more time with fixes for the issues that came up:

  • unused variable "lsrc",
  • use of ‘outer1’ before deduction of ‘auto’.

Patch is 22.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153488.diff

5 Files Affected:

  • (modified) flang/lib/Semantics/check-omp-atomic.cpp (+255-41)
  • (modified) flang/lib/Semantics/check-omp-structure.h (+3-1)
  • (added) flang/test/Lower/OpenMP/atomic-update-reassoc.f90 (+75)
  • (modified) flang/test/Semantics/OpenMP/atomic-update-only.f90 (+9-2)
  • (modified) flang/test/Semantics/OpenMP/atomic04.f90 (+1-2)
diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp
index 0c0e6158485e9..9d92be6327fdb 100644
--- a/flang/lib/Semantics/check-omp-atomic.cpp
+++ b/flang/lib/Semantics/check-omp-atomic.cpp
@@ -13,7 +13,9 @@
 #include "check-omp-structure.h"
 
 #include "flang/Common/indirection.h"
+#include "flang/Common/template.h"
 #include "flang/Evaluate/expression.h"
+#include "flang/Evaluate/match.h"
 #include "flang/Evaluate/rewrite.h"
 #include "flang/Evaluate/tools.h"
 #include "flang/Parser/char-block.h"
@@ -50,6 +52,138 @@ static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) {
   return !(e == f);
 }
 
+namespace {
+template <typename...> struct IsIntegral {
+  static constexpr bool value{false};
+};
+
+template <common::TypeCategory C, int K>
+struct IsIntegral<evaluate::Type<C, K>> {
+  static constexpr bool value{//
+      C == common::TypeCategory::Integer ||
+      C == common::TypeCategory::Unsigned ||
+      C == common::TypeCategory::Logical};
+};
+
+template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
+
+template <typename T, typename Op0, typename Op1>
+using ReassocOpBase = evaluate::match::AnyOfPattern< //
+    evaluate::match::Add<T, Op0, Op1>, //
+    evaluate::match::Mul<T, Op0, Op1>>;
+
+template <typename T, typename Op0, typename Op1>
+struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
+  using Base = ReassocOpBase<T, Op0, Op1>;
+  using Base::Base;
+};
+
+template <typename T, typename Op0, typename Op1>
+ReassocOp<T, Op0, Op1> reassocOp(const Op0 &op0, const Op1 &op1) {
+  return ReassocOp<T, Op0, Op1>(op0, op1);
+}
+} // namespace
+
+struct ReassocRewriter : public evaluate::rewrite::Identity {
+  using Id = evaluate::rewrite::Identity;
+  using Id::operator();
+  struct NonIntegralTag {};
+
+  ReassocRewriter(const SomeExpr &atom) : atom_(atom) {}
+
+  // Try to find cases where the input expression is of the form
+  // (1) (a . b) . c, or
+  // (2) a . (b . c),
+  // where . denotes an associative operation (currently + or *), and a, b, c
+  // are some subexpresions.
+  // If one of the operands in the nested operation is the atomic variable
+  // (with some possible type conversions applied to it), bring it to the
+  // top-level operation, and move the top-level operand into the nested
+  // operation.
+  // For example, assuming x is the atomic variable:
+  //   (a + x) + b  ->  (a + b) + x,  i.e. (conceptually) swap x and b.
+  template <typename T, typename U,
+      typename = std::enable_if_t<is_integral_v<T>>>
+  evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
+    // As per the above comment, there are 3 subexpressions involved in this
+    // transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
+    // same as U, plus it will store a pointer (ref) to the matched expression.
+    // When the match is successful, the sub[i].ref will point to a, b, x (in
+    // some order) from the example above.
+    evaluate::match::Expr<T> sub[3];
+    auto inner{reassocOp<T>(sub[0], sub[1])};
+    auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
+    auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
+#if !defined(__clang__) && !defined(_MSC_VER) && \
+      (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
+    // If GCC version < 8.5, use this definition. For the other definition
+    // (which is equivalent), GCC 7.5 emits a somewhat cryptic error:
+    //    use of ‘outer1’ before deduction of ‘auto’
+    // inside of the visitor function in common::visit.
+    // Since this works with clang, MSVC and at least GCC 8.5, I'm assuming
+    // that this is some kind of a GCC issue.
+    using MatchTypes = std::tuple<evaluate::Add<T>, evaluate::Multiply<T>>;
+#else
+    using MatchTypes = typename decltype(outer1)::MatchTypes;
+#endif
+    // There is no way to ensure that the outer operation is the same as
+    // the inner one. They are matched independently, so we need to compare
+    // the index in the member variant that represents the matched type.
+    if ((match(outer1, x) && outer1.ref.index() == inner.ref.index()) ||
+        (match(outer2, x) && outer2.ref.index() == inner.ref.index())) {
+      size_t atomIdx{[&]() { // sub[atomIdx] will be the atom.
+        size_t idx;
+        for (idx = 0; idx != 3; ++idx) {
+          if (IsAtom(*sub[idx].ref)) {
+            break;
+          }
+        }
+        return idx;
+      }()};
+
+      if (atomIdx > 2) {
+        return Id::operator()(std::move(x), u);
+      }
+      return common::visit(
+          [&](auto &&s) {
+            using Expr = evaluate::Expr<T>;
+            using TypeS = llvm::remove_cvref_t<decltype(s)>;
+            // This visitor has to be semantically correct for all possible
+            // types of s even though at runtime s will only be one of the
+            // matched types.
+            // Limit the construction to the operation types that we tried
+            // to match (otherwise TypeS(op1, op2) would fail for non-binary
+            // operations).
+            if constexpr (common::HasMember<TypeS, MatchTypes>) {
+              Expr atom{*sub[atomIdx].ref};
+              Expr op1{*sub[(atomIdx + 1) % 3].ref};
+              Expr op2{*sub[(atomIdx + 2) % 3].ref};
+              return Expr(
+                  TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
+            } else {
+              return Expr(TypeS(s));
+            }
+          },
+          evaluate::match::deparen(x).u);
+    }
+    return Id::operator()(std::move(x), u);
+  }
+
+  template <typename T, typename U,
+      typename = std::enable_if_t<!is_integral_v<T>>>
+  evaluate::Expr<T> operator()(
+      evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
+    return Id::operator()(std::move(x), u);
+  }
+
+private:
+  template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
+    return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
+  }
+
+  const SomeExpr &atom_;
+};
+
 struct AnalyzedCondStmt {
   SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
   parser::CharBlock source;
@@ -199,6 +333,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
   llvm_unreachable("Could not find assignment operator");
 }
 
+static std::vector<SomeExpr> GetNonAtomExpressions(
+    const SomeExpr &atom, const std::vector<SomeExpr> &exprs) {
+  std::vector<SomeExpr> nonAtom;
+  for (const SomeExpr &e : exprs) {
+    if (!IsSameOrConvertOf(e, atom)) {
+      nonAtom.push_back(e);
+    }
+  }
+  return nonAtom;
+}
+
+static std::vector<SomeExpr> GetNonAtomArguments(
+    const SomeExpr &atom, const SomeExpr &expr) {
+  if (auto &&maybe{GetConvertInput(expr)}) {
+    return GetNonAtomExpressions(
+        atom, GetTopLevelOperationIgnoreResizing(*maybe).second);
+  }
+  return {};
+}
+
 static bool IsCheckForAssociated(const SomeExpr &cond) {
   return GetTopLevelOperationIgnoreResizing(cond).first ==
       operation::Operator::Associated;
@@ -576,6 +730,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
     const evaluate::Assignment &capture, const SomeExpr &atom,
     parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
   const SomeExpr &cap{capture.lhs};
 
   if (!IsVarOrFunctionRef(atom)) {
@@ -592,6 +747,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
 void OmpStructureChecker::CheckAtomicReadAssignment(
     const evaluate::Assignment &read, parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
 
   if (auto maybe{GetConvertInput(read.rhs)}) {
     const SomeExpr &atom{*maybe};
@@ -625,7 +781,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment(
   }
 }
 
-void OmpStructureChecker::CheckAtomicUpdateAssignment(
+std::optional<evaluate::Assignment>
+OmpStructureChecker::CheckAtomicUpdateAssignment(
     const evaluate::Assignment &update, parser::CharBlock source) {
   // [6.0:191:1-7]
   // An update structured block is update-statement, an update statement
@@ -641,14 +798,47 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   if (!IsVarOrFunctionRef(atom)) {
     ErrorShouldBeVariable(atom, rsrc);
     // Skip other checks.
-    return;
+    return std::nullopt;
   }
 
   CheckAtomicVariable(atom, lsrc);
 
+  auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/true)};
+
+  if (!hasErrors) {
+    CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source);
+    return std::nullopt;
+  } else if (tryReassoc) {
+    ReassocRewriter ra(atom);
+    SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)};
+
+    std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs(
+        atom, raRhs, source, /*suppressDiagnostics=*/true);
+    if (!hasErrors) {
+      CheckStorageOverlap(atom, GetNonAtomArguments(atom, raRhs), source);
+
+      evaluate::Assignment raAssign(update);
+      raAssign.rhs = raRhs;
+      return raAssign;
+    }
+  }
+
+  // This is guaranteed to report errors.
+  CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/false);
+  return std::nullopt;
+}
+
+std::pair<bool, bool> OmpStructureChecker::CheckAtomicUpdateAssignmentRhs(
+    const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source,
+    bool suppressDiagnostics) {
+  auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
+
   std::pair<operation::Operator, std::vector<SomeExpr>> top{
       operation::Operator::Unknown, {}};
-  if (auto &&maybeInput{GetConvertInput(update.rhs)}) {
+  if (auto &&maybeInput{GetConvertInput(rhs)}) {
     top = GetTopLevelOperationIgnoreResizing(*maybeInput);
   }
   switch (top.first) {
@@ -665,29 +855,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   case operation::Operator::Identity:
     break;
   case operation::Operator::Call:
-    context_.Say(source,
-        "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Convert:
-    context_.Say(source,
-        "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Intrinsic:
-    context_.Say(source,
-        "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Constant:
   case operation::Operator::Unknown:
-    context_.Say(
-        source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(
+          source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   default:
     assert(
         top.first != operation::Operator::Identity && "Handle this separately");
-    context_.Say(source,
-        "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
-        operation::ToString(top.first));
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
+          operation::ToString(top.first));
+    }
+    return std::make_pair(true, false);
   }
   // Check how many times `atom` occurs as an argument, if it's a subexpression
   // of an argument, and collect the non-atom arguments.
@@ -708,39 +908,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
     return count;
   }()};
 
-  bool hasError{false};
+  bool hasError{false}, tryReassoc{false};
   if (subExpr) {
-    context_.Say(rsrc,
-        "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
-        atom.AsFortran(), subExpr->AsFortran());
+    if (!suppressDiagnostics) {
+      context_.Say(rsrc,
+          "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
+          atom.AsFortran(), subExpr->AsFortran());
+    }
     hasError = true;
   }
   if (top.first == operation::Operator::Identity) {
     // This is "x = y".
     assert((atomCount == 0 || atomCount == 1) && "Unexpected count");
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
-          atom.AsFortran());
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
+            atom.AsFortran());
+      }
       hasError = true;
     }
   } else {
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
+      // If `atom` is a proper subexpression, and it not present as an
+      // argument on its own, reassociation may be able to help.
+      tryReassoc = subExpr.has_value();
       hasError = true;
     } else if (atomCount > 1) {
-      context_.Say(rsrc,
-          "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
       hasError = true;
     }
   }
 
-  if (!hasError) {
-    CheckStorageOverlap(atom, nonAtom, source);
-  }
+  return std::make_pair(hasError, tryReassoc);
 }
 
 void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment(
@@ -843,11 +1052,13 @@ void OmpStructureChecker::CheckAtomicUpdateOnly(
     SourcedActionStmt action{GetActionStmt(&body.front())};
     if (auto maybeUpdate{GetEvaluateAssignment(action.stmt)}) {
       const SomeExpr &atom{maybeUpdate->lhs};
-      CheckAtomicUpdateAssignment(*maybeUpdate, action.source);
+      auto maybeAssign{
+          CheckAtomicUpdateAssignment(*maybeUpdate, action.source)};
+      auto &updateAssign{maybeAssign.has_value() ? maybeAssign : maybeUpdate};
 
       using Analysis = parser::OpenMPAtomicConstruct::Analysis;
       x.analysis = AtomicAnalysis(atom)
-                       .addOp0(Analysis::Update, maybeUpdate)
+                       .addOp0(Analysis::Update, updateAssign)
                        .addOp1(Analysis::None);
     } else if (!IsAssignment(action.stmt)) {
       context_.Say(
@@ -963,16 +1174,19 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
   using Analysis = parser::OpenMPAtomicConstruct::Analysis;
   int action;
 
+  std::optional<evaluate::Assignment> updateAssign{update};
   if (IsMaybeAtomicWrite(update)) {
     action = Analysis::Write;
     CheckAtomicWriteAssignment(update, uact.source);
   } else {
     action = Analysis::Update;
-    CheckAtomicUpdateAssignment(update, uact.source);
+    if (auto &&maybe{CheckAtomicUpdateAssignment(update, uact.source)}) {
+      updateAssign = maybe;
+    }
   }
   CheckAtomicCaptureAssignment(capture, atom, cact.source);
 
-  if (IsPointerAssignment(update) != IsPointerAssignment(capture)) {
+  if (IsPointerAssignment(*updateAssign) != IsPointerAssignment(capture)) {
     context_.Say(cact.source,
         "The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments"_err_en_US);
     return;
@@ -980,12 +1194,12 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
 
   if (GetActionStmt(&body.front()).stmt == uact.stmt) {
     x.analysis = AtomicAnalysis(atom)
-                     .addOp0(action, update)
+                     .addOp0(action, updateAssign)
                      .addOp1(Analysis::Read, capture);
   } else {
     x.analysis = AtomicAnalysis(atom)
                      .addOp0(Analysis::Read, capture)
-                     .addOp1(action, update);
+                     .addOp1(action, updateAssign);
   }
 }
 
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 6b33ca6ab583f..a973aee28d0e2 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -267,8 +267,10 @@ class OmpStructureChecker
       const evaluate::Assignment &read, parser::CharBlock source);
   void CheckAtomicWriteAssignment(
       const evaluate::Assignment &write, parser::CharBlock source);
-  void CheckAtomicUpdateAssignment(
+  std::optional<evaluate::Assignment> CheckAtomicUpdateAssignment(
       const evaluate::Assignment &update, parser::CharBlock source);
+  std::pair<bool, bool> CheckAtomicUpdateAssignmentRhs(const SomeExpr &atom,
+      const SomeExpr &rhs, parser::CharBlock source, bool suppressDiagnostics);
   void CheckAtomicConditionalUpdateAssignment(const SomeExpr &cond,
       parser::CharBlock condSource, const evaluate::Assignment &assign,
       parser::CharBlock assignSource);
diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
new file mode 100644
index 0000000000000..96ebb56b8d6ca
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
@@ -0,0 +1,75 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+subroutine f00(x, y)
+  implicit none
+  integer :: x, y
+
+  !$omp atomic update
+  x = ((x + 1) + y) + 2
+end
+
+!CHECK-LABEL: func.func @_QPf00
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %c1_i32, %[[LOAD_Y]] : i32
+!CHECK: %c2_i32 = arith.constant 2 : i32
+!CHECK: %[[Y_1_2:[0-9]+]] = arith.addi %[[Y_1]], %c2_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: i32):
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG]], %[[Y_1_2]] : i32
+!CHECK:   omp.yield(%[[ARG_P]] : i32)
+!CHECK: }
+
+
+subroutine f01(x, y)
+  implicit none
+  real :: x
+  integer :: y
+
+  !$omp atomic update
+  x = (int(x) + y) + 1
+end
+
+!CHECK-LABEL: func.func @_QPf01
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %[[LOAD_Y]], %c1_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<f32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: f32):
+!CHECK:   %[[ARG_I:[0-9]+]] = fir.convert %[[ARG]] : (f32) -> i32
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG_I]], %[[Y_1]] : i32
+!CHECK:   %[[ARG_F:[0-9]+]] = fir.convert %[[ARG_P]] : (i32) -> f32
+!CHECK:   omp.yield(%[[ARG_F]] : f32)
+!CHECK: }
+
+
+subroutine f02(x, a, b, c)
+  implicit none
+  integer(kind=4) :: x
+  integer(kind=8) :: a, b, c
+
+  !$omp atomic update
+  x = ((b + a) + x) + c
+end
+
+!CHECK-LABEL: func.func @_QPf02
+!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<i64>
+!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<i64>
+!CHECK: %[[A_B:[0-9]+]] = arith.addi %[[LOAD_B]], ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 13, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Krzysztof Parzyszek (kparzysz)

Changes

An atomic update expression of form
x = x + a + b
is technically illegal, since the right-hand side is parsed as (x+a)+b, and the atomic variable x should be an argument to the top-level +. When the type of x is integer, the result of (x+a)+b is guaranteed to be the same as x+(a+b), so instead of reporting an error, the compiler can treat (x+a)+b as x+(a+b).

This PR implements this kind of reassociation for integral types, and for the two arithmetic associative/commutative operators: + and *.

Reinstate PR153098 one more time with fixes for the issues that came up:

  • unused variable "lsrc",
  • use of ‘outer1’ before deduction of ‘auto’.

Patch is 22.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153488.diff

5 Files Affected:

  • (modified) flang/lib/Semantics/check-omp-atomic.cpp (+255-41)
  • (modified) flang/lib/Semantics/check-omp-structure.h (+3-1)
  • (added) flang/test/Lower/OpenMP/atomic-update-reassoc.f90 (+75)
  • (modified) flang/test/Semantics/OpenMP/atomic-update-only.f90 (+9-2)
  • (modified) flang/test/Semantics/OpenMP/atomic04.f90 (+1-2)
diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp
index 0c0e6158485e9..9d92be6327fdb 100644
--- a/flang/lib/Semantics/check-omp-atomic.cpp
+++ b/flang/lib/Semantics/check-omp-atomic.cpp
@@ -13,7 +13,9 @@
 #include "check-omp-structure.h"
 
 #include "flang/Common/indirection.h"
+#include "flang/Common/template.h"
 #include "flang/Evaluate/expression.h"
+#include "flang/Evaluate/match.h"
 #include "flang/Evaluate/rewrite.h"
 #include "flang/Evaluate/tools.h"
 #include "flang/Parser/char-block.h"
@@ -50,6 +52,138 @@ static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) {
   return !(e == f);
 }
 
+namespace {
+template <typename...> struct IsIntegral {
+  static constexpr bool value{false};
+};
+
+template <common::TypeCategory C, int K>
+struct IsIntegral<evaluate::Type<C, K>> {
+  static constexpr bool value{//
+      C == common::TypeCategory::Integer ||
+      C == common::TypeCategory::Unsigned ||
+      C == common::TypeCategory::Logical};
+};
+
+template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
+
+template <typename T, typename Op0, typename Op1>
+using ReassocOpBase = evaluate::match::AnyOfPattern< //
+    evaluate::match::Add<T, Op0, Op1>, //
+    evaluate::match::Mul<T, Op0, Op1>>;
+
+template <typename T, typename Op0, typename Op1>
+struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
+  using Base = ReassocOpBase<T, Op0, Op1>;
+  using Base::Base;
+};
+
+template <typename T, typename Op0, typename Op1>
+ReassocOp<T, Op0, Op1> reassocOp(const Op0 &op0, const Op1 &op1) {
+  return ReassocOp<T, Op0, Op1>(op0, op1);
+}
+} // namespace
+
+struct ReassocRewriter : public evaluate::rewrite::Identity {
+  using Id = evaluate::rewrite::Identity;
+  using Id::operator();
+  struct NonIntegralTag {};
+
+  ReassocRewriter(const SomeExpr &atom) : atom_(atom) {}
+
+  // Try to find cases where the input expression is of the form
+  // (1) (a . b) . c, or
+  // (2) a . (b . c),
+  // where . denotes an associative operation (currently + or *), and a, b, c
+  // are some subexpresions.
+  // If one of the operands in the nested operation is the atomic variable
+  // (with some possible type conversions applied to it), bring it to the
+  // top-level operation, and move the top-level operand into the nested
+  // operation.
+  // For example, assuming x is the atomic variable:
+  //   (a + x) + b  ->  (a + b) + x,  i.e. (conceptually) swap x and b.
+  template <typename T, typename U,
+      typename = std::enable_if_t<is_integral_v<T>>>
+  evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
+    // As per the above comment, there are 3 subexpressions involved in this
+    // transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
+    // same as U, plus it will store a pointer (ref) to the matched expression.
+    // When the match is successful, the sub[i].ref will point to a, b, x (in
+    // some order) from the example above.
+    evaluate::match::Expr<T> sub[3];
+    auto inner{reassocOp<T>(sub[0], sub[1])};
+    auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
+    auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
+#if !defined(__clang__) && !defined(_MSC_VER) && \
+      (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
+    // If GCC version < 8.5, use this definition. For the other definition
+    // (which is equivalent), GCC 7.5 emits a somewhat cryptic error:
+    //    use of ‘outer1’ before deduction of ‘auto’
+    // inside of the visitor function in common::visit.
+    // Since this works with clang, MSVC and at least GCC 8.5, I'm assuming
+    // that this is some kind of a GCC issue.
+    using MatchTypes = std::tuple<evaluate::Add<T>, evaluate::Multiply<T>>;
+#else
+    using MatchTypes = typename decltype(outer1)::MatchTypes;
+#endif
+    // There is no way to ensure that the outer operation is the same as
+    // the inner one. They are matched independently, so we need to compare
+    // the index in the member variant that represents the matched type.
+    if ((match(outer1, x) && outer1.ref.index() == inner.ref.index()) ||
+        (match(outer2, x) && outer2.ref.index() == inner.ref.index())) {
+      size_t atomIdx{[&]() { // sub[atomIdx] will be the atom.
+        size_t idx;
+        for (idx = 0; idx != 3; ++idx) {
+          if (IsAtom(*sub[idx].ref)) {
+            break;
+          }
+        }
+        return idx;
+      }()};
+
+      if (atomIdx > 2) {
+        return Id::operator()(std::move(x), u);
+      }
+      return common::visit(
+          [&](auto &&s) {
+            using Expr = evaluate::Expr<T>;
+            using TypeS = llvm::remove_cvref_t<decltype(s)>;
+            // This visitor has to be semantically correct for all possible
+            // types of s even though at runtime s will only be one of the
+            // matched types.
+            // Limit the construction to the operation types that we tried
+            // to match (otherwise TypeS(op1, op2) would fail for non-binary
+            // operations).
+            if constexpr (common::HasMember<TypeS, MatchTypes>) {
+              Expr atom{*sub[atomIdx].ref};
+              Expr op1{*sub[(atomIdx + 1) % 3].ref};
+              Expr op2{*sub[(atomIdx + 2) % 3].ref};
+              return Expr(
+                  TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
+            } else {
+              return Expr(TypeS(s));
+            }
+          },
+          evaluate::match::deparen(x).u);
+    }
+    return Id::operator()(std::move(x), u);
+  }
+
+  template <typename T, typename U,
+      typename = std::enable_if_t<!is_integral_v<T>>>
+  evaluate::Expr<T> operator()(
+      evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
+    return Id::operator()(std::move(x), u);
+  }
+
+private:
+  template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
+    return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
+  }
+
+  const SomeExpr &atom_;
+};
+
 struct AnalyzedCondStmt {
   SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
   parser::CharBlock source;
@@ -199,6 +333,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
   llvm_unreachable("Could not find assignment operator");
 }
 
+static std::vector<SomeExpr> GetNonAtomExpressions(
+    const SomeExpr &atom, const std::vector<SomeExpr> &exprs) {
+  std::vector<SomeExpr> nonAtom;
+  for (const SomeExpr &e : exprs) {
+    if (!IsSameOrConvertOf(e, atom)) {
+      nonAtom.push_back(e);
+    }
+  }
+  return nonAtom;
+}
+
+static std::vector<SomeExpr> GetNonAtomArguments(
+    const SomeExpr &atom, const SomeExpr &expr) {
+  if (auto &&maybe{GetConvertInput(expr)}) {
+    return GetNonAtomExpressions(
+        atom, GetTopLevelOperationIgnoreResizing(*maybe).second);
+  }
+  return {};
+}
+
 static bool IsCheckForAssociated(const SomeExpr &cond) {
   return GetTopLevelOperationIgnoreResizing(cond).first ==
       operation::Operator::Associated;
@@ -576,6 +730,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
     const evaluate::Assignment &capture, const SomeExpr &atom,
     parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
   const SomeExpr &cap{capture.lhs};
 
   if (!IsVarOrFunctionRef(atom)) {
@@ -592,6 +747,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
 void OmpStructureChecker::CheckAtomicReadAssignment(
     const evaluate::Assignment &read, parser::CharBlock source) {
   auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
 
   if (auto maybe{GetConvertInput(read.rhs)}) {
     const SomeExpr &atom{*maybe};
@@ -625,7 +781,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment(
   }
 }
 
-void OmpStructureChecker::CheckAtomicUpdateAssignment(
+std::optional<evaluate::Assignment>
+OmpStructureChecker::CheckAtomicUpdateAssignment(
     const evaluate::Assignment &update, parser::CharBlock source) {
   // [6.0:191:1-7]
   // An update structured block is update-statement, an update statement
@@ -641,14 +798,47 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   if (!IsVarOrFunctionRef(atom)) {
     ErrorShouldBeVariable(atom, rsrc);
     // Skip other checks.
-    return;
+    return std::nullopt;
   }
 
   CheckAtomicVariable(atom, lsrc);
 
+  auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/true)};
+
+  if (!hasErrors) {
+    CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source);
+    return std::nullopt;
+  } else if (tryReassoc) {
+    ReassocRewriter ra(atom);
+    SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)};
+
+    std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs(
+        atom, raRhs, source, /*suppressDiagnostics=*/true);
+    if (!hasErrors) {
+      CheckStorageOverlap(atom, GetNonAtomArguments(atom, raRhs), source);
+
+      evaluate::Assignment raAssign(update);
+      raAssign.rhs = raRhs;
+      return raAssign;
+    }
+  }
+
+  // This is guaranteed to report errors.
+  CheckAtomicUpdateAssignmentRhs(
+      atom, update.rhs, source, /*suppressDiagnostics=*/false);
+  return std::nullopt;
+}
+
+std::pair<bool, bool> OmpStructureChecker::CheckAtomicUpdateAssignmentRhs(
+    const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source,
+    bool suppressDiagnostics) {
+  auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+  (void)lsrc;
+
   std::pair<operation::Operator, std::vector<SomeExpr>> top{
       operation::Operator::Unknown, {}};
-  if (auto &&maybeInput{GetConvertInput(update.rhs)}) {
+  if (auto &&maybeInput{GetConvertInput(rhs)}) {
     top = GetTopLevelOperationIgnoreResizing(*maybeInput);
   }
   switch (top.first) {
@@ -665,29 +855,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
   case operation::Operator::Identity:
     break;
   case operation::Operator::Call:
-    context_.Say(source,
-        "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Convert:
-    context_.Say(source,
-        "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Intrinsic:
-    context_.Say(source,
-        "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   case operation::Operator::Constant:
   case operation::Operator::Unknown:
-    context_.Say(
-        source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(
+          source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
+    }
+    return std::make_pair(true, false);
   default:
     assert(
         top.first != operation::Operator::Identity && "Handle this separately");
-    context_.Say(source,
-        "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
-        operation::ToString(top.first));
-    return;
+    if (!suppressDiagnostics) {
+      context_.Say(source,
+          "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
+          operation::ToString(top.first));
+    }
+    return std::make_pair(true, false);
   }
   // Check how many times `atom` occurs as an argument, if it's a subexpression
   // of an argument, and collect the non-atom arguments.
@@ -708,39 +908,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
     return count;
   }()};
 
-  bool hasError{false};
+  bool hasError{false}, tryReassoc{false};
   if (subExpr) {
-    context_.Say(rsrc,
-        "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
-        atom.AsFortran(), subExpr->AsFortran());
+    if (!suppressDiagnostics) {
+      context_.Say(rsrc,
+          "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
+          atom.AsFortran(), subExpr->AsFortran());
+    }
     hasError = true;
   }
   if (top.first == operation::Operator::Identity) {
     // This is "x = y".
     assert((atomCount == 0 || atomCount == 1) && "Unexpected count");
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
-          atom.AsFortran());
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
+            atom.AsFortran());
+      }
       hasError = true;
     }
   } else {
     if (atomCount == 0) {
-      context_.Say(rsrc,
-          "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
+      // If `atom` is a proper subexpression, and it not present as an
+      // argument on its own, reassociation may be able to help.
+      tryReassoc = subExpr.has_value();
       hasError = true;
     } else if (atomCount > 1) {
-      context_.Say(rsrc,
-          "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
-          atom.AsFortran(), operation::ToString(top.first));
+      if (!suppressDiagnostics) {
+        context_.Say(rsrc,
+            "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
+            atom.AsFortran(), operation::ToString(top.first));
+      }
       hasError = true;
     }
   }
 
-  if (!hasError) {
-    CheckStorageOverlap(atom, nonAtom, source);
-  }
+  return std::make_pair(hasError, tryReassoc);
 }
 
 void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment(
@@ -843,11 +1052,13 @@ void OmpStructureChecker::CheckAtomicUpdateOnly(
     SourcedActionStmt action{GetActionStmt(&body.front())};
     if (auto maybeUpdate{GetEvaluateAssignment(action.stmt)}) {
       const SomeExpr &atom{maybeUpdate->lhs};
-      CheckAtomicUpdateAssignment(*maybeUpdate, action.source);
+      auto maybeAssign{
+          CheckAtomicUpdateAssignment(*maybeUpdate, action.source)};
+      auto &updateAssign{maybeAssign.has_value() ? maybeAssign : maybeUpdate};
 
       using Analysis = parser::OpenMPAtomicConstruct::Analysis;
       x.analysis = AtomicAnalysis(atom)
-                       .addOp0(Analysis::Update, maybeUpdate)
+                       .addOp0(Analysis::Update, updateAssign)
                        .addOp1(Analysis::None);
     } else if (!IsAssignment(action.stmt)) {
       context_.Say(
@@ -963,16 +1174,19 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
   using Analysis = parser::OpenMPAtomicConstruct::Analysis;
   int action;
 
+  std::optional<evaluate::Assignment> updateAssign{update};
   if (IsMaybeAtomicWrite(update)) {
     action = Analysis::Write;
     CheckAtomicWriteAssignment(update, uact.source);
   } else {
     action = Analysis::Update;
-    CheckAtomicUpdateAssignment(update, uact.source);
+    if (auto &&maybe{CheckAtomicUpdateAssignment(update, uact.source)}) {
+      updateAssign = maybe;
+    }
   }
   CheckAtomicCaptureAssignment(capture, atom, cact.source);
 
-  if (IsPointerAssignment(update) != IsPointerAssignment(capture)) {
+  if (IsPointerAssignment(*updateAssign) != IsPointerAssignment(capture)) {
     context_.Say(cact.source,
         "The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments"_err_en_US);
     return;
@@ -980,12 +1194,12 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
 
   if (GetActionStmt(&body.front()).stmt == uact.stmt) {
     x.analysis = AtomicAnalysis(atom)
-                     .addOp0(action, update)
+                     .addOp0(action, updateAssign)
                      .addOp1(Analysis::Read, capture);
   } else {
     x.analysis = AtomicAnalysis(atom)
                      .addOp0(Analysis::Read, capture)
-                     .addOp1(action, update);
+                     .addOp1(action, updateAssign);
   }
 }
 
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 6b33ca6ab583f..a973aee28d0e2 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -267,8 +267,10 @@ class OmpStructureChecker
       const evaluate::Assignment &read, parser::CharBlock source);
   void CheckAtomicWriteAssignment(
       const evaluate::Assignment &write, parser::CharBlock source);
-  void CheckAtomicUpdateAssignment(
+  std::optional<evaluate::Assignment> CheckAtomicUpdateAssignment(
       const evaluate::Assignment &update, parser::CharBlock source);
+  std::pair<bool, bool> CheckAtomicUpdateAssignmentRhs(const SomeExpr &atom,
+      const SomeExpr &rhs, parser::CharBlock source, bool suppressDiagnostics);
   void CheckAtomicConditionalUpdateAssignment(const SomeExpr &cond,
       parser::CharBlock condSource, const evaluate::Assignment &assign,
       parser::CharBlock assignSource);
diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
new file mode 100644
index 0000000000000..96ebb56b8d6ca
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-update-reassoc.f90
@@ -0,0 +1,75 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+subroutine f00(x, y)
+  implicit none
+  integer :: x, y
+
+  !$omp atomic update
+  x = ((x + 1) + y) + 2
+end
+
+!CHECK-LABEL: func.func @_QPf00
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %c1_i32, %[[LOAD_Y]] : i32
+!CHECK: %c2_i32 = arith.constant 2 : i32
+!CHECK: %[[Y_1_2:[0-9]+]] = arith.addi %[[Y_1]], %c2_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: i32):
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG]], %[[Y_1_2]] : i32
+!CHECK:   omp.yield(%[[ARG_P]] : i32)
+!CHECK: }
+
+
+subroutine f01(x, y)
+  implicit none
+  real :: x
+  integer :: y
+
+  !$omp atomic update
+  x = (int(x) + y) + 1
+end
+
+!CHECK-LABEL: func.func @_QPf01
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<i32>
+!CHECK: %c1_i32 = arith.constant 1 : i32
+!CHECK: %[[Y_1:[0-9]+]] = arith.addi %[[LOAD_Y]], %c1_i32 : i32
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<f32> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: f32):
+!CHECK:   %[[ARG_I:[0-9]+]] = fir.convert %[[ARG]] : (f32) -> i32
+!CHECK:   %[[ARG_P:[0-9]+]] = arith.addi %[[ARG_I]], %[[Y_1]] : i32
+!CHECK:   %[[ARG_F:[0-9]+]] = fir.convert %[[ARG_P]] : (i32) -> f32
+!CHECK:   omp.yield(%[[ARG_F]] : f32)
+!CHECK: }
+
+
+subroutine f02(x, a, b, c)
+  implicit none
+  integer(kind=4) :: x
+  integer(kind=8) :: a, b, c
+
+  !$omp atomic update
+  x = ((b + a) + x) + c
+end
+
+!CHECK-LABEL: func.func @_QPf02
+!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<i64>
+!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<i64>
+!CHECK: %[[A_B:[0-9]+]] = arith.addi %[[LOAD_B]], ...
[truncated]

@github-actions
Copy link

github-actions bot commented Aug 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@kparzysz kparzysz merged commit 1e7772a into main Aug 13, 2025
9 checks passed
@kparzysz kparzysz deleted the users/kparzysz/w05-rewrite-reassoc branch August 13, 2025 20:54
Meinersbur added a commit that referenced this pull request Aug 14, 2025
PR #153488 caused the msvc build (https://lab.llvm.org/buildbot/#/builders/166/builds/1397) to fail:
```
..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): error C2668: 'Fortran::evaluate::rewrite::Identity::operator ()': ambiguous call to overloaded function
..\llvm-project\flang\include\flang/Evaluate/rewrite.h(43): note: could be 'Fortran::evaluate::Expr<Fortran::evaluate::SomeType> Fortran::evaluate::rewrite::Identity::operator ()<Fortran::evaluate::SomeType,S>(Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &&,const U &)'
        with
        [
            S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>,
            U=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>
        ]
..\llvm-project\flang\lib\Semantics\check-omp-atomic.cpp(174): note: or       'Fortran::evaluate::Expr<Fortran::evaluate::SomeType> Fortran::semantics::ReassocRewriter::operator ()<Fortran::evaluate::SomeType,S,void>(Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &&,const U &,Fortran::semantics::ReassocRewriter::NonIntegralTag)'
        with
        [
            S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>,
            U=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>
        ]
..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): note: while trying to match the argument list '(Fortran::evaluate::Expr<Fortran::evaluate::SomeType>, const S)'
        with
        [
            S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>
        ]
..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): note: the template instantiation context (the oldest one first) is
..\llvm-project\flang\lib\Semantics\check-omp-atomic.cpp(814): note: see reference to function template instantiation 'U Fortran::evaluate::rewrite::Mutator<Fortran::semantics::ReassocRewriter>::operator ()<const Fortran::evaluate::Expr<Fortran::evaluate::SomeType>&,Fortran::evaluate::Expr<Fortran::evaluate::SomeType>>(T)' being compiled
        with
        [
            U=Fortran::evaluate::Expr<Fortran::evaluate::SomeType>,
            T=const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &
        ]
```

The reason is that there is an ambiguity between operator() of
ReassocRewriter itself and operator() of the base class `Identity` through
`using Id::operator();`. By the C++ specification, method declarations
in ReassocRewriter hide methods with the same signature from a using
declaration, but this does not apply to
```
evaluate::Expr<T> operator()(..., NonIntegralTag = {})
```
which has a different signature due to an additional tag parameter.
Since it has a default value, it is ambiguous with operator() without
tag parameter.

GCC and Clang both accept this, but in my understanding MSVC is correct
here.

Since the overloads of ReassocRewriter cover all cases (integral and
non-integral), removing the using declaration to avoid the ambiguity.
@Meinersbur
Copy link
Member

Meinersbur commented Aug 14, 2025

I fixed a compilation break in 38853a0.

I would recommend using if constexpr () instead of tricks with overloads. I think using Id::operator() is intended as a fallaback when the derived class's method SFINAE fails, but seems only gcc accepts that. With the additional overload with tag, there is no case where overload resolution would fall back to using Id::operator().

Meinersbur added a commit to Meinersbur/llvm-project that referenced this pull request Aug 22, 2025
PR llvm#153488 caused the msvc build (https://lab.llvm.org/buildbot/#/builders/166/builds/1397) to fail:
```
..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): error C2668: 'Fortran::evaluate::rewrite::Identity::operator ()': ambiguous call to overloaded function
..\llvm-project\flang\include\flang/Evaluate/rewrite.h(43): note: could be 'Fortran::evaluate::Expr<Fortran::evaluate::SomeType> Fortran::evaluate::rewrite::Identity::operator ()<Fortran::evaluate::SomeType,S>(Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &&,const U &)'
        with
        [
            S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>,
            U=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>
        ]
..\llvm-project\flang\lib\Semantics\check-omp-atomic.cpp(174): note: or       'Fortran::evaluate::Expr<Fortran::evaluate::SomeType> Fortran::semantics::ReassocRewriter::operator ()<Fortran::evaluate::SomeType,S,void>(Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &&,const U &,Fortran::semantics::ReassocRewriter::NonIntegralTag)'
        with
        [
            S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>,
            U=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>
        ]
..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): note: while trying to match the argument list '(Fortran::evaluate::Expr<Fortran::evaluate::SomeType>, const S)'
        with
        [
            S=Fortran::evaluate::value::Integer<128,true,32,unsigned int,unsigned __int64,128>
        ]
..\llvm-project\flang\include\flang/Evaluate/rewrite.h(78): note: the template instantiation context (the oldest one first) is
..\llvm-project\flang\lib\Semantics\check-omp-atomic.cpp(814): note: see reference to function template instantiation 'U Fortran::evaluate::rewrite::Mutator<Fortran::semantics::ReassocRewriter>::operator ()<const Fortran::evaluate::Expr<Fortran::evaluate::SomeType>&,Fortran::evaluate::Expr<Fortran::evaluate::SomeType>>(T)' being compiled
        with
        [
            U=Fortran::evaluate::Expr<Fortran::evaluate::SomeType>,
            T=const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &
        ]
```

The reason is that there is an ambiguite between operator() of
ReassocRewriter itself as operator() of the base class Identity through
`using Id::operator();`. By the C++ specification, method declarations
in ReassocRewriter hide methods with the same signature from a using
declaration, but this does not apply to
```
evaluate::Expr<T> operator()(..., NonIntegralTag = {})
```
which has a different signature due to an additional tag parameter.
Since it has a default value, it is ambiguous with operator() without
tag parameter.

GCC and Clang both accept this, but in my understanding MSVC is correct
here.

Since the overloads of ReassocRewriter cover all cases, remopving the
using reclarations to avoid the ambiguity.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang:openmp flang:semantics flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants