Stylie777
Contributor

Currently, Flang does not correctly lower ArrayElements when processing an OpenMP Reduction Clause. Rather than lowering the array element, the whole array is lowered. This leads to slower performance in the end user's program.

This patch rectifies this by rewriting the parse tree while processing semantics. The use of an ArrayElement in an OpenMP Reduction Clause is identified and replaced with a temporary, both in the reduction clause and anywhere that array element is used within the respective DoConstruct. Once the DoConstruct has finished, if the ArrayElement has been used within the Do loop, the value of the temporary is assigned back to the array element. One limitation of this approach is that if the ArrayElement is not used in the loop, there is no element available in the parse tree from which to build the reassignment, so it is only done if the element is used.
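
The rewrite described above can be pictured with a toy model. The sketch below is not flang code: `ToyLoop`, `replaceAll` and `rewriteLoop` are hypothetical names, statements are plain strings rather than parse tree nodes, and the name of the temporary is supplied by the caller. It only illustrates the three steps: copy-in before the loop, substitution in the reduction clause and loop body, and copy-out only when the element was used.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for an OpenMP loop construct (hypothetical, not a flang type).
struct ToyLoop {
  std::string reductionObject;           // e.g. "a(2)"
  std::vector<std::string> doConstruct;  // lines of the DO construct
};

// Replace every occurrence of `from` in `text`; report whether any was found.
bool replaceAll(
    std::string &text, const std::string &from, const std::string &to) {
  bool replaced{false};
  for (std::size_t pos{text.find(from)}; pos != std::string::npos;
       pos = text.find(from, pos + to.size())) {
    text.replace(pos, from.size(), to);
    replaced = true;
  }
  return replaced;
}

// Returns the statement sequence after the rewrite.
std::vector<std::string> rewriteLoop(ToyLoop loop, const std::string &temp) {
  const std::string element{loop.reductionObject};
  std::vector<std::string> out;
  out.push_back(temp + " = " + element);  // copy-in before the loop
  bool usedInLoop{false};
  for (std::string &line : loop.doConstruct) {
    // Every use of the element inside the loop now names the temporary.
    if (replaceAll(line, element, temp)) {
      usedInLoop = true;
    }
  }
  out.push_back("!$omp do reduction(+:" + temp + ")");
  for (const std::string &line : loop.doConstruct) {
    out.push_back(line);
  }
  out.push_back("!$omp end do");
  if (usedInLoop) {
    out.push_back(element + " = " + temp);  // copy-out only when used
  }
  return out;
}
```

Given the reduction object `a(2)` and a loop body containing `a(2) = a(2) + i`, the model emits `tmp = a(2)` before the directive and `a(2) = tmp` after it. With a body that never mentions `a(2)`, the trailing assignment is omitted, which corresponds to the limitation noted above.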

The change is made in the parse tree because of how ArrayElements are lowered. When lowering, the expression of the ArrayElement being used in the reduction is substituted with the reference to the symbol. In this case, that would be the whole array. Replacing it with a temporary removes the issue of lowering a full array, as the reduction references the temporary instead. Addressing this in lowering would require a major rethink of how a considerable amount of non-OpenMP code is lowered and, as such, was not deemed the appropriate course of action for this specific case.

This process is done after the initial Semantics Pass so as not to affect the checking of the user's original code. If the array element has been replaced, the first pass of semantics needs to be rerun to ensure all TypedExprs are correctly captured; otherwise the lowering will not function correctly. This step is only done if an ArrayElement is replaced.

Testing is covered by reduction17.f90. This checks the parse tree, unparsing and HLFIR to ensure the temporary is being used in the reduction clause and Do loop. The assignment of the ArrayElement to the temporary, and the reassignment from the temporary back to the ArrayElement, are also checked to ensure they are inserted at the correct location.

reduction09.f90 has also been reformatted to rely on FileCheck. As the parse tree is changing, the output differs from the user's source, so we can no longer rely on test_symbols.py for this test. The same information is still checked, with test cases covering both using and not using the ArrayElement in the Do loop.

Array Sections are not affected by this change, only uses of single array elements.

@Stylie777 Stylie777 requested review from kparzysz and tblah October 17, 2025 10:51
@llvmbot llvmbot added the flang, flang:openmp and flang:semantics labels Oct 17, 2025
@llvmbot
Member

llvmbot commented Oct 17, 2025

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-flang-semantics

Author: Jack Styles (Stylie777)

Patch is 46.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163940.diff

5 Files Affected:

  • (modified) flang/lib/Semantics/rewrite-parse-tree.cpp (+524)
  • (modified) flang/lib/Semantics/rewrite-parse-tree.h (+2)
  • (modified) flang/lib/Semantics/semantics.cpp (+12-1)
  • (modified) flang/test/Semantics/OpenMP/reduction09.f90 (+90-19)
  • (added) flang/test/Semantics/OpenMP/reduction17.f90 (+209)
diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp
index 5b7dab309eda7..5379dcdd3d40c 100644
--- a/flang/lib/Semantics/rewrite-parse-tree.cpp
+++ b/flang/lib/Semantics/rewrite-parse-tree.cpp
@@ -95,6 +95,86 @@ class RewriteMutator {
   parser::Messages &messages_;
 };
 
+class ReplacementTemp {
+public:
+  ReplacementTemp() {}
+
+  void createTempSymbol(
+      SourceName &source, Scope &scope, SemanticsContext &context);
+  void setOriginalSubscriptInt(
+      std::list<parser::SectionSubscript> &sectionSubscript);
+  Symbol *getTempSymbol() { return replacementTempSymbol_; }
+  Symbol *getPrivateReductionSymbol() { return privateReductionSymbol_; }
+  parser::CharBlock getOriginalSource() { return originalSource_; }
+  parser::Name getOriginalName() { return originalName_; }
+  parser::CharBlock getOriginalSubscript() {
+    return originalSubscriptCharBlock_;
+  }
+  Scope *getTempScope() { return tempScope_; }
+  bool isArrayElementReassigned() { return arrayElementReassigned_; }
+  bool isSectionTriplet() { return isSectionTriplet_; }
+  void arrayElementReassigned() { arrayElementReassigned_ = true; }
+  void setOriginalName(parser::Name &name) {
+    originalName_ = common::Clone(name);
+  }
+  void setOriginalSource(parser::CharBlock &source) {
+    originalSource_ = source;
+  }
+  void setOriginalSubscriptInt(parser::CharBlock &subscript) {
+    originalSubscriptCharBlock_ = subscript;
+  }
+  void setTempScope(Scope &scope) { tempScope_ = &scope; };
+  void setTempSymbol(Symbol *symbol) { replacementTempSymbol_ = symbol; }
+
+private:
+  Symbol *replacementTempSymbol_{nullptr};
+  Symbol *privateReductionSymbol_{nullptr};
+  Scope *tempScope_{nullptr};
+  parser::CharBlock originalSource_;
+  parser::Name originalName_;
+  parser::CharBlock originalSubscriptCharBlock_;
+  bool arrayElementReassigned_{false};
+  bool isSectionTriplet_{false};
+};
+
+class RewriteOmpReductionArrayElements {
+public:
+  explicit RewriteOmpReductionArrayElements(SemanticsContext &context)
+      : context_(context) {}
+  // Default action for a parse tree node is to visit children.
+  template <typename T> bool Pre(T &) { return true; }
+  template <typename T> void Post(T &) {}
+
+  void Post(parser::Block &block);
+  void Post(parser::Variable &var);
+  void Post(parser::Expr &expr);
+  void Post(parser::AssignmentStmt &assignmentStmt);
+  void Post(parser::PointerAssignmentStmt &ptrAssignmentStmt);
+  void rewriteReductionArrayElementToTemp(parser::Block &block);
+  bool isArrayElementRewritten() { return arrayElementReassigned_; }
+
+private:
+  bool isMatchingArrayElement(parser::Designator &existingDesignator);
+  template <typename T>
+  void processFunctionReference(
+      T &node, parser::CharBlock source, parser::FunctionReference &funcRef);
+  parser::Designator makeTempDesignator(parser::CharBlock source);
+  bool rewriteArrayElementToTemp(parser::Block::iterator &it,
+      parser::OpenMPLoopConstruct &ompLoop, parser::Block &block,
+      ReplacementTemp &temp);
+  bool identifyArrayElementReduced(
+      parser::Designator &designator, ReplacementTemp &temp);
+  void reassignTempValueToArrayElement(parser::ArrayElement &arrayElement);
+  void setCurrentTemp(ReplacementTemp *temp) { currentTemp_ = temp; }
+  void resetCurrentTemp() { currentTemp_ = nullptr; }
+
+  SemanticsContext &context_;
+  bool arrayElementReassigned_{false};
+  parser::Block::iterator reassignmentInsertionPoint_;
+  parser::Block *block_{nullptr};
+  ReplacementTemp *currentTemp_{nullptr};
+};
+
 // Check that name has been resolved to a symbol
 void RewriteMutator::Post(parser::Name &name) {
   if (!name.symbol && errorOnUnresolvedName_) {
@@ -492,10 +572,454 @@ void RewriteMutator::Post(parser::WriteStmt &x) {
   FixMisparsedUntaggedNamelistName(x);
 }
 
+void ReplacementTemp::createTempSymbol(
+    SourceName &source, Scope &scope, SemanticsContext &context) {
+  replacementTempSymbol_ =
+      const_cast<semantics::Scope &>(originalName_.symbol->owner())
+          .FindSymbol(source);
+  replacementTempSymbol_->set_scope(
+      &const_cast<semantics::Scope &>(originalName_.symbol->owner()));
+  DeclTypeSpec *tempType = originalName_.symbol->GetUltimate().GetType();
+  replacementTempSymbol_->get<ObjectEntityDetails>().set_type(*tempType);
+  replacementTempSymbol_->flags().set(Symbol::Flag::CompilerCreated);
+}
+
+void ReplacementTemp::setOriginalSubscriptInt(
+    std::list<parser::SectionSubscript> &sectionSubscript) {
+  bool setSubscript{false};
+  for (parser::SectionSubscript &subscript : sectionSubscript) {
+    std::visit(llvm::makeVisitor(
+                   [&](parser::IntExpr &intExpr) {
+                     parser::Expr &expr = intExpr.thing.value();
+                     std::visit(
+                         llvm::makeVisitor(
+                             [&](parser::LiteralConstant &literalContant) {
+                               std::visit(llvm::makeVisitor(
+                                              [&](parser::IntLiteralConstant
+                                                      &intLiteralConstant) {
+                                                originalSubscriptCharBlock_ =
+                                                    std::get<parser::CharBlock>(
+                                                        intLiteralConstant.t);
+                                                setSubscript = true;
+                                              },
+                                              [&](auto &) {}),
+                                   literalContant.u);
+                             },
+                             [&](auto &) {}),
+                         expr.u);
+                   },
+                   [&](parser::SubscriptTriplet &triplet) {
+                     isSectionTriplet_ = true;
+                     setSubscript = true;
+                   },
+                   [&](auto &) {}),
+        subscript.u);
+    if (setSubscript) {
+      break;
+    }
+  }
+}
+
+void RewriteOmpReductionArrayElements::rewriteReductionArrayElementToTemp(
+    parser::Block &block) {
+  if (block.empty()) {
+    return;
+  }
+
+  for (auto it{block.begin()}; it != block.end(); ++it) {
+    std::visit(
+        llvm::makeVisitor(
+            [&](parser::ExecutableConstruct &execConstruct) {
+              std::visit(
+                  llvm::makeVisitor(
+                      [&](common::Indirection<parser::OpenMPConstruct>
+                              &ompConstruct) {
+                        std::visit(
+                            llvm::makeVisitor(
+                                [&](parser::OpenMPLoopConstruct &ompLoop) {
+                                  ReplacementTemp temp;
+                                  if (!rewriteArrayElementToTemp(
+                                          it, ompLoop, block, temp)) {
+                                    return;
+                                  }
+                                  auto &NestedConstruct = std::get<
+                                      std::optional<parser::NestedConstruct>>(
+                                      ompLoop.t);
+                                  if (!NestedConstruct.has_value()) {
+                                    return;
+                                  }
+                                  if (parser::DoConstruct *
+                                      doConst{std::get_if<parser::DoConstruct>(
+                                          &NestedConstruct.value())}) {
+                                    block_ = &block;
+                                    parser::Block &doBlock{
+                                        std::get<parser::Block>(doConst->t)};
+                                    parser::Walk(doBlock, *this);
+                                    // Reset the current temp value so future
+                                    // iterations use their own version.
+                                    resetCurrentTemp();
+                                  }
+                                },
+                                [&](auto &) {}),
+                            ompConstruct.value().u);
+                      },
+                      [&](auto &) {}),
+                  execConstruct.u);
+            },
+            [&](auto &) {}),
+        it->u);
+  }
+}
+
+bool RewriteOmpReductionArrayElements::isMatchingArrayElement(
+    parser::Designator &existingDesignator) {
+  bool matchesArrayElement{false};
+  std::list<parser::SectionSubscript> *subscripts{nullptr};
+
+  std::visit(llvm::makeVisitor(
+                 [&](parser::DataRef &dataRef) {
+                   std::visit(
+                       llvm::makeVisitor(
+                           [&](common::Indirection<parser::ArrayElement>
+                                   &arrayElement) {
+                             subscripts = &arrayElement.value().subscripts;
+                             std::visit(
+                                 llvm::makeVisitor(
+                                     [&](parser::Name &name) {
+                                       if (name.symbol->GetUltimate() ==
+                                           currentTemp_->getOriginalName()
+                                               .symbol->GetUltimate()) {
+                                         matchesArrayElement = true;
+                                         if (!currentTemp_
+                                                 ->isArrayElementReassigned()) {
+                                           reassignTempValueToArrayElement(
+                                               arrayElement.value());
+                                         }
+                                       }
+                                     },
+                                     [](auto &) {}),
+                                 arrayElement.value().base.u);
+                           },
+                           [&](parser::Name &name) {
+                             if (name.symbol->GetUltimate() ==
+                                 currentTemp_->getOriginalName()
+                                     .symbol->GetUltimate()) {
+                               matchesArrayElement = true;
+                             }
+                           },
+                           [](auto &) {}),
+                       dataRef.u);
+                 },
+                 [&](auto &) {}),
+      existingDesignator.u);
+
+  if (subscripts) {
+    bool foundSubscript{false};
+    for (parser::SectionSubscript &subscript : *subscripts) {
+      matchesArrayElement = std::visit(
+          llvm::makeVisitor(
+              [&](parser::IntExpr &intExpr) -> bool {
+                parser::Expr &expr = intExpr.thing.value();
+                return std::visit(
+                    llvm::makeVisitor(
+                        [&](parser::LiteralConstant &literalContant) -> bool {
+                          return std::visit(
+                              llvm::makeVisitor(
+                                  [&](parser::IntLiteralConstant
+                                          &intLiteralConstant) -> bool {
+                                    foundSubscript = true;
+                                    assert(currentTemp_ != nullptr &&
+                                        "Value for ReplacementTemp should have "
+                                        "been found");
+                                    if (std::get<parser::CharBlock>(
+                                            intLiteralConstant.t) ==
+                                        currentTemp_->getOriginalSubscript()) {
+                                      return true;
+                                    }
+                                    return false;
+                                  },
+                                  [](auto &) -> bool { return false; }),
+                              literalContant.u);
+                        },
+                        [](auto &) -> bool { return false; }),
+                    expr.u);
+              },
+              [](auto &) -> bool { return false; }),
+          subscript.u);
+      if (foundSubscript) {
+        break;
+      }
+    }
+  }
+  return matchesArrayElement;
+}
+
+template <typename T>
+void RewriteOmpReductionArrayElements::processFunctionReference(
+    T &node, parser::CharBlock source, parser::FunctionReference &funcRef) {
+  auto &[procedureDesignator, ArgSpecList] = funcRef.v.t;
+  std::optional<parser::Designator> arrayElementDesignator =
+      std::visit(llvm::makeVisitor(
+                     [&](parser::Name &functionReferenceName)
+                         -> std::optional<parser::Designator> {
+                       if (currentTemp_->getOriginalName().symbol ==
+                           functionReferenceName.symbol) {
+                         return funcRef.ConvertToArrayElementRef();
+                       }
+                       return std::nullopt;
+                     },
+                     [&](auto &) -> std::optional<parser::Designator> {
+                       return std::nullopt;
+                     }),
+          procedureDesignator.u);
+
+  if (arrayElementDesignator.has_value()) {
+    if (this->isMatchingArrayElement(arrayElementDesignator.value())) {
+      node = T{
+          common::Indirection<parser::Designator>{makeTempDesignator(source)}};
+    }
+  }
+}
+
+parser::Designator RewriteOmpReductionArrayElements::makeTempDesignator(
+    parser::CharBlock source) {
+  parser::Name tempVariableName{currentTemp_->getTempSymbol()->name()};
+  tempVariableName.symbol = currentTemp_->getTempSymbol();
+  parser::Designator tempDesignator{
+      parser::DataRef{std::move(tempVariableName)}};
+  tempDesignator.source = source;
+  return tempDesignator;
+}
+
+bool RewriteOmpReductionArrayElements::rewriteArrayElementToTemp(
+    parser::Block::iterator &it, parser::OpenMPLoopConstruct &ompLoop,
+    parser::Block &block, ReplacementTemp &temp) {
+  parser::OmpBeginLoopDirective &ompBeginLoop{
+      std::get<parser::OmpBeginLoopDirective>(ompLoop.t)};
+  std::list<parser::OmpClause> &clauseList{
+      std::get<std::optional<parser::OmpClauseList>>(ompBeginLoop.t)->v};
+  bool rewrittenArrayElement{false};
+
+  for (auto iter{clauseList.begin()}; iter != clauseList.end(); ++iter) {
+    rewrittenArrayElement = std::visit(
+        llvm::makeVisitor(
+            [&](parser::OmpClause::Reduction &clause) -> bool {
+              std::list<parser::OmpObject> &objectList =
+                  std::get<parser::OmpObjectList>(clause.v.t).v;
+
+              bool rewritten{false};
+              for (auto object{objectList.begin()}; object != objectList.end();
+                  ++object) {
+                rewritten = std::visit(
+                    llvm::makeVisitor(
+                        [&](parser::Designator &designator) -> bool {
+                          if (!identifyArrayElementReduced(designator, temp)) {
+                            return false;
+                          }
+                          if (temp.isSectionTriplet()) {
+                            return false;
+                          }
+
+                          reassignmentInsertionPoint_ =
+                              it != block.end() ? it : block.end();
+                          std::string tempSourceString = "reduction_temp_" +
+                              temp.getOriginalSource().ToString() + "(" +
+                              temp.getOriginalSubscript().ToString() + ")";
+                          SourceName source = context_.SaveTempName(
+                              std::move(tempSourceString));
+                          Scope &scope = const_cast<Scope &>(
+                              temp.getOriginalName().symbol->owner());
+                          if (Symbol * symbol{scope.FindSymbol(source)}) {
+                            temp.setTempSymbol(symbol);
+                          } else {
+                            if (scope
+                                    .try_emplace(source, semantics::Attrs{},
+                                        semantics::ObjectEntityDetails{})
+                                    .second) {
+                              temp.createTempSymbol(source, scope, context_);
+                            } else {
+                              common::die("Failed to create temp symbol for %s",
+                                  source.ToString().c_str());
+                            }
+                          }
+                          setCurrentTemp(&temp);
+                          temp.setTempScope(scope);
+
+                          // Assign the value of the array element to the
+                          // temporary variable
+                          parser::Variable newVariable{
+                              makeTempDesignator(temp.getOriginalSource())};
+                          parser::Expr newExpr{
+                              common::Indirection<parser::Designator>{
+                                  std::move(designator)}};
+                          newExpr.source = temp.getOriginalSource();
+                          std::tuple<parser::Variable, parser::Expr> newT{
+                              std::move(newVariable), std::move(newExpr)};
+                          parser::AssignmentStmt assignment{std::move(newT)};
+                          parser::ExecutionPartConstruct
+                              tempVariablePartConstruct{
+                                  parser::ExecutionPartConstruct{
+                                      parser::ExecutableConstruct{
+                                          parser::Statement<parser::ActionStmt>{
+                                              std::optional<parser::Label>{},
+                                              std::move(assignment)}}}};
+                          block.insert(
+                              it, std::move(tempVariablePartConstruct));
+                          arrayElementReassigned_ = true;
+
+                          designator =
+                              makeTempDesignator(temp.getOriginalSource());
+                          return true;
+                        },
+                        [&](const auto &) -> bool { return false; }),
+                    object->u);
+              }
+              return rewritten;
+            },
+            [&](auto &) -> bool { return false; }),
+        iter->u);
+
+    if (rewrittenArrayElement) {
+      return rewrittenArrayElement;
+    }
+  }
+  return rewrittenArrayElement;
+}
+
+bool RewriteOmpReductionArrayElements::identifyArrayElementReduced(
+    parser::Designator &designator, ReplacementTemp &temp) {
+  return std::visit(
+      llvm::makeVisitor(
+          [&](parser::DataRef &dataRef) -> bool {
+            return std::visit(
+                llvm::makeVisitor(
+                    [&](common::Indirection<parser::ArrayElement>
+                            &arrayElement) {
+                      std::visit(llvm::makeVisitor(
+                                     [&](parser::Name &name) -> void {
+                                       temp.setOriginalName(name);
+                                       temp.setOriginalSource(name.source);
+                                     },
+                                     [&](auto &) -> void {}),
+                          arrayElement.value().base.u);
+                      temp.setOriginalSubscriptInt(
+                          arrayElement.value().subscripts);
+                      return !temp.isSectionTriplet() ? true : false;
+                    },
+                    [&](auto &) -> bool { return false; }),
+                dataRef.u);
+          },
+          [&](auto &) -> bool { return false; }),
+      designator.u);
+}
+
+void RewriteOmpReductionArrayElements::reassignTempValueToArrayElement(
+    parser::ArrayElement &arrayElement) {
+  assert(block_ && "Need iterator to reassign");
+  parser::CharBlock originalSource = currentTemp_->getOriginalSource();
+  parser::DataRef reassignmentDataRef{std::move(arrayElement)};
+  common::Indirection<parser::Designator> arrayElementDesignator{
+      std::move(reassignmentDataRef)};
+  arrayElementDesignator.value().source = originalSource;
+  parser::Variable exisitingVar{std::move(arrayElementDesignator)};
+  std::get<c...
[truncated]

@Stylie777 Stylie777 changed the title from "[flang][OpenMP] Fix reduction of Scalar ArrayElement types" to "[flang][OpenMP] Improve reduction of Scalar ArrayElement types" Oct 17, 2025
@kiranchandramohan
Contributor

I forgot what was the difficulty here in reduction. I suspect that we faced a similar issue in OpenMP atomic (example source below) and it was handled during lowering.

integer :: j(2)
!$omp atomic
j(1) = j(1) + 1
!$omp end atomic
end

In other words, is it not possible to get a reference to the array element and then use that as the value in the reduction clause's argument in MLIR?

      %6:2 = hlfir.declare %arg0 {uniq_name = "_QFEa"} : (!fir.ref<!fir.box<!fir.array<2xi32>>>) -> (!fir.ref<!fir.box<!fir.array<2xi32>>>, !fir.ref<!fir.box<!fir.array<2xi32>>>)
      %7 = fir.load %6#0 : !fir.ref<!fir.box<!fir.array<2xi32>>>
      %c1 = arith.constant 1 : index
      %8 = hlfir.designate %7 (%c1)  : (!fir.box<!fir.array<2xi32>>, index) -> !fir.ref<i3

Something like the following:

omp.parallel reduction(byref @add_reduction_byref_i32 %8 -> %a1 : !fir.ref<i32>)

Contributor

@kparzysz kparzysz left a comment

Some initial comments.

The large indentations in some places make the code harder to read. Could you extract some of the visitor lambdas into separate entities? e.g.

std::visit(
  [&](...) {
    [&] (...) {
    }
  }
)

to

auto foo = [&] (...) {
};
auto bar = [&] (...) {
  foo
};
std::visit(...bar...);

Or something like that to reduce the amount of leading whitespace.
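
One way to read this suggestion, as a minimal self-contained sketch. The node types (`IntLiteral`, `NameRef`, `Triplet`) and the `Overloaded` helper are simplified stand-ins introduced here for illustration so that it compiles with only the standard library; the patch itself works on flang's parser types and uses `llvm::makeVisitor`.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Simplified stand-ins for parse tree nodes (hypothetical, not flang types).
struct IntLiteral { std::string text; };
struct NameRef { std::string name; };
struct Triplet {};
using Expr = std::variant<IntLiteral, NameRef>;
using Subscript = std::variant<Expr, Triplet>;

// Merges several lambdas into one visitor; plays the role that
// llvm::makeVisitor plays in the patch.
template <typename... Fs> struct Overloaded : Fs... {
  using Fs::operator()...;
};
template <typename... Fs> Overloaded(Fs...) -> Overloaded<Fs...>;

// Returns the literal text of a subscript, or "" if it is not a literal.
std::string literalSubscript(const Subscript &subscript) {
  // Each level of the former nesting becomes a named visitor, so the
  // indentation stays flat however deep the variant hierarchy is.
  auto visitExpr = Overloaded{
      [](const IntLiteral &literal) { return literal.text; },
      [](const auto &) { return std::string{}; }};
  auto visitSubscript = Overloaded{
      [&](const Expr &expr) { return std::visit(visitExpr, expr); },
      [](const auto &) { return std::string{}; }};
  return std::visit(visitSubscript, subscript);
}
```

The inner visitor is defined first and captured by reference in the outer one, so the call site shrinks to a single `std::visit`.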

@kparzysz
Contributor

I forgot what was the difficulty here in reduction.

I'm also not exactly sure what the issue is. When you (Jack) say that the whole array is lowered, what do you mean exactly? Could you maybe share a small example to illustrate the problem?

@kparzysz
Contributor

One limitation of this approach is that if the ArrayElement is not used, there is no available element in the parse tree to use to reassign the value, so its only done if used.

Correct me if I'm reading the code wrong, but it seems like you do the reassignment back to the original array only when that element is used, but you do the replacement with a scalar in all cases. If so, that will produce incorrect code:

subroutine foo(a, b)
  integer, intent(inout) :: a(10), b(10)
  integer :: i
  !$omp parallel do reduction(+: a(3))
  do i = 1, 10
    a(3) = a(3) + b(i)  ! update visible to callers of foo
  end do
  !$omp end parallel do
end

@Stylie777
Contributor Author

@kiranchandramohan From my understanding, addressing this in lowering is not possible. An array element is an expression, and the reference to the Symbol is used, which in this case would be the full array. This is ok for Array Sections, but if the user is only using a single array element, it does not make sense to reduce the full array, hence the rewrite to the temp so it can pick up a symbol which is just a single value.

With this rewrite, the MLIR example you provided with a single integer being reduced is what ends up being generated, but its not possible to do this using the array element itself because of the process I explained above.

@kparzysz an example of the lowering as it is would be, say for a small program

program test
    integer :: a(2)
    integer :: i

    !$omp do reduction(+:a(2))
    do i=1,5
        a(2) = a(2) + i
    end do
    !$omp end do

end program test

When the reduction appears in High Level FIR, it's generated as

reduction(byref @add_reduction_byref_box_2xi32 %7 -> %arg1 : !fir.ref<!fir.box<!fir.array<2xi32>>>)

As the user only needs one element of the array, it does not make sense to use the whole array in the reduction. This patch turns that into the following, with the assignment to and from the temp in the relevant places:

reduction(@add_reduction_i32 %7#0 -> %arg1 : !fir.ref<i32>) {

With the example you have given in the subroutine foo(a, b), the reassignment after the loop is done correctly: the pass detects the use of the array element within the loop and uses one of those occurrences of ArrayElement in the parse tree to build the reassignment expression. The issue I have is that the reassignment expression is built from a use of the ArrayElement inside the do loop, so if the element is not used within the loop, I cannot build it. But if it's not used within the loop, do we need to reassign it back? Probably not, which was the justification for taking this approach: in reality the value would not change anyway if you reduced the whole array.

@tblah

tblah commented Oct 17, 2025

When generating the body of the construct containing the reduction clause, we need to ensure that the references to the reduced value use the private reduction variable and not some host variable. The way this works for reductions (and I think privatisation) is that we create a host associated symbol for the variable being reduced and that shadows the original shared copy: ensuring that lowering of references to that symbol inside of the body instead point to the reduction block argument.

This works well when whole symbols are reduced, but not when what is reduced is an expression. We don't have a convenient way to instruct lowering to replace all instances of a parse tree expression with a new value (the reduction block argument). I couldn't think of a way to add that without changing a lot of non-OpenMP code. The workaround is to rewrite something like

!$omp parallel reduction(+: a(10))
...
a(10) = a(10) + 10 
!$omp end parallel

into

tmp = a(10)
!$omp parallel reduction(+: tmp)
...
tmp = tmp + 10 
!$omp end parallel
a(10) = tmp

This is a first step towards supporting reductions on components of derived types (which will be handled with a temporary in a similar way).
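The rewrite described above can be sketched on a toy model. Everything here is an assumption made for illustration: `Ref`, `Assign`, `Region`, `Rewritten`, and `rewriteWithTemp` are hypothetical names, operators inside expressions are omitted, and the copy-out is emitted unconditionally, whereas the patch as described in this thread only emits it when the element is used in the loop.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, heavily simplified model of a parse tree; not Flang's types.
struct Ref {                     // a whole variable, or one array element
  std::string name;
  std::optional<int> subscript;  // empty for a scalar or a whole array
  bool operator==(const Ref &o) const {
    return name == o.name && subscript == o.subscript;
  }
};
struct Assign {                  // lhs = f(operands); operators are omitted
  Ref lhs;
  std::vector<Ref> operands;
};
struct Region {                  // a construct carrying a reduction clause
  std::vector<Ref> reductionList;
  std::vector<Assign> body;
};
struct Rewritten {
  Assign copyIn;                 // tmp = a(10)
  Region region;                 // region with a(10) replaced by tmp
  Assign copyOut;                // a(10) = tmp
};

// Replaces every occurrence of `element` in the clause and the body with a
// scalar temporary, and brackets the region with copy-in and copy-out.
Rewritten rewriteWithTemp(Region region, const Ref &element,
                          const std::string &tempName) {
  const Ref temp{tempName, std::nullopt};
  auto replace = [&](Ref &ref) {
    if (ref == element) ref = temp;
  };
  for (Ref &ref : region.reductionList) replace(ref);
  for (Assign &stmt : region.body) {
    replace(stmt.lhs);
    for (Ref &ref : stmt.operands) replace(ref);
  }
  return Rewritten{Assign{temp, {element}}, std::move(region),
                   Assign{element, {temp}}};
}

// Small usage check: reduction(+: a(10)) with body a(10) = a(10) + b(1).
void checkRewrite() {
  Ref a10{"a", 10};
  Ref tmp{"tmp", std::nullopt};
  Region region{{a10}, {Assign{a10, {a10, Ref{"b", 1}}}}};
  Rewritten out = rewriteWithTemp(region, a10, "tmp");
  assert(out.region.reductionList[0] == tmp);
  assert(out.region.body[0].lhs == tmp);
  assert(out.copyOut.lhs == a10);
}
```

After the rewrite, the reduction clause names a plain scalar, which is the case that the host-associated-symbol mechanism described above already handles.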

@kparzysz

I'm ok with this approach, but I think we need to develop a solution in lowering in the longer term.

I'm concerned that doing something like the example above could potentially introduce subtle race conditions:

tmp = a(10)
!$omp parallel reduction(+: tmp)
...
tmp = tmp + 10 
!$omp end parallel  ! a(10) would normally be up-to-date here on exit from the parallel region
a(10) = tmp         ! but it isn't until here

I think we can probably get away with this for now.

collate all std::visit's into lambda functions to make them easier
to read.

@Stylie777 Stylie777 left a comment

@kparzysz thanks for the initial review. I have also gone through and extracted the std::visit parts of the code into lambda functions so it is easier to follow and review as you suggested.

@kiranchandramohan

I haven't looked at this in detail. FYI: an expression to mlir::Value map was added in #69944 for simplifying codegen for atomic.

@kparzysz kparzysz left a comment

Thanks, it looks a lot clearer now.

This code will only handle one reduction variable per reduction clause. For example, given

subroutine f00(x, y)
  integer :: x(:), y(:)
  integer :: i
  !$omp parallel do reduction(+: x(2), y(3))
  do i = 1, 200
    x(2) = x(2) + i
    y(3) = 6 - y(3)
  end do
  !$omp end parallel do
end

it generates

SUBROUTINE f00 (x, y)
 INTEGER x(:), y(:)
 INTEGER i
  reduction_temp_x(2)=x(2_8)
  reduction_temp_y(3)=y(3_8)
!$OMP PARALLEL DO REDUCTION(+: reduction_temp_x(2),reduction_temp_y(3))
 DO i=1_4,200_4
   x(2_8)=x(2_8)+i    ! unmodified
   reduction_temp_y(3)=6_4-reduction_temp_y(3)
 END DO
!$OMP END PARALLEL DO
  y(3_8)=reduction_temp_y(3)
END SUBROUTINE

leaving x(2) = x(2) + i unmodified. Is that a known limitation?

Also, if you modify the testcase a bit:

subroutine f00(x, y, j)
  integer :: x(:), y(:)
  integer :: i, j
  !$omp parallel do reduction(+: x(2), y(j))
  do i = 1, 200
    x(2) = x(2) + i
    y(j) = 6 - y(j)
  end do
  !$omp end parallel do
end

you get incorrect output:

SUBROUTINE f00 (x, y, j)
 INTEGER x(:), y(:)
 INTEGER i, j
  reduction_temp_x(2)=x(2_8)
  reduction_temp_y(2)=y(int(j,kind=8))
!$OMP PARALLEL DO REDUCTION(+: reduction_temp_x(2),reduction_temp_y(2))
 DO i=1_4,200_4
   reduction_temp_y(2)=reduction_temp_y(2)+i
   reduction_temp_y(2)=6_4-y(int(j,kind=8))
 END DO
!$OMP END PARALLEL DO
  y(int(j,kind=8))=reduction_temp_y(2)
END SUBROUTINE
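One way to avoid mixing up several reduction objects is to keep a table keyed by the full element, so that each `array(subscript)` pair maps to its own temporary. The sketch below is only an illustration under stated assumptions: `ReductionTemps`, `tempFor`, `lookup`, and the generated names are hypothetical, and subscripts are compared as plain text, which a real implementation would need to replace with proper expression comparison.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>

// Hypothetical helper; the names and the key format are illustrative only.
class ReductionTemps {
public:
  // Returns the temporary for `array(subscript)`, creating it on first use.
  const std::string &tempFor(const std::string &array,
                             const std::string &subscript) {
    auto key = std::make_pair(array, subscript);
    auto it = temps_.find(key);
    if (it == temps_.end()) {
      // A running index keeps names distinct even for equal subscripts.
      std::string name =
          "reduction_temp_" + array + "_" + std::to_string(temps_.size());
      it = temps_.emplace(std::move(key), std::move(name)).first;
    }
    return it->second;
  }

  // Returns nullptr when `array(subscript)` is not a reduction object.
  const std::string *lookup(const std::string &array,
                            const std::string &subscript) const {
    auto it = temps_.find(std::make_pair(array, subscript));
    return it == temps_.end() ? nullptr : &it->second;
  }

private:
  std::map<std::pair<std::string, std::string>, std::string> temps_;
};

// Small usage check: x(2) and y(j) get separate temporaries.
void checkTemps() {
  ReductionTemps temps;
  const std::string x = temps.tempFor("x", "2");
  const std::string y = temps.tempFor("y", "j");
  assert(x != y);
  assert(temps.tempFor("x", "2") == x);
  assert(temps.lookup("x", "3") == nullptr);
}
```

With a table like this, a use of `x(2)` in the loop body would look up the `x` entry and could never be replaced by the temporary created for `y(j)`.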

}
}
};
auto visitArratElement = [&](parser::ArrayElement &arrayElement) {

Typo: arrat

Comment on lines +816 to +819
parser::OmpBeginLoopDirective &ompBeginLoop{
std::get<parser::OmpBeginLoopDirective>(ompLoop.t)};
std::list<parser::OmpClause> &clauseList{
std::get<std::optional<parser::OmpClauseList>>(ompBeginLoop.t)->v};

In case you didn't know, OpenMPLoopConstruct has a function "BeginDir" that returns a reference to the contained "OmpBeginLoopDirective". The OmpBeginLoopDirective inherits from OmpDirectiveSpecification that has member function Clauses() that returns the list of clauses.
You could do

auto &clauseList{const_cast<std::list<parser::OmpClause> &>(ompLoop.BeginDir().Clauses().v)};

or

auto &clauseList{const_cast<parser::OmpClauseList &>(ompLoop.BeginDir().Clauses())};

and use clauseList.v.

Comment on lines +869 to +870
std::list<parser::OmpObject> &objectList =
std::get<parser::OmpObjectList>(clause.v.t).v;

FYI, there is a function in flang/include/flang/Parser/openmp-utils.h that gets OmpObjectList from any clause (or returns nullptr if the clause doesn't have one):
const OmpObjectList *GetOmpObjectList(const OmpClause &clause);

Labels

flang:openmp, flang:semantics, flang

5 participants