From 59ca403b295c706dbd15d4291b7fd34591403164 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 8 Aug 2025 14:46:40 -0500 Subject: [PATCH 1/4] [flang][OpenMP] Refactor creating atomic analysis, NFC Turn it into a class that combines the information and generates the analysis instead of having independent functions do it. --- flang/lib/Semantics/check-omp-atomic.cpp | 161 ++++++++++++++--------- 1 file changed, 97 insertions(+), 64 deletions(-) diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index fcb0f9ad1e25d..a5fe820b1069b 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -222,47 +222,77 @@ static void SetAssignment(parser::AssignmentStmt::TypedAssignment &assign, } } -static parser::OpenMPAtomicConstruct::Analysis::Op MakeAtomicAnalysisOp( - int what, - const std::optional &maybeAssign = std::nullopt) { - parser::OpenMPAtomicConstruct::Analysis::Op operation; - operation.what = what; - SetAssignment(operation.assign, maybeAssign); - return operation; -} +namespace { +struct AtomicAnalysis { + AtomicAnalysis(const SomeExpr &atom, const MaybeExpr &cond = std::nullopt) + : atom_(atom), cond_(cond) {} + + AtomicAnalysis &addOp0(int what, + const std::optional &maybeAssign = std::nullopt) { + return addOp(op0_, what, maybeAssign); + } + AtomicAnalysis &addOp1(int what, + const std::optional &maybeAssign = std::nullopt) { + return addOp(op1_, what, maybeAssign); + } -static parser::OpenMPAtomicConstruct::Analysis MakeAtomicAnalysis( - const SomeExpr &atom, const MaybeExpr &cond, - parser::OpenMPAtomicConstruct::Analysis::Op &&op0, - parser::OpenMPAtomicConstruct::Analysis::Op &&op1) { - // Defined in flang/include/flang/Parser/parse-tree.h - // - // struct Analysis { - // struct Kind { - // static constexpr int None = 0; - // static constexpr int Read = 1; - // static constexpr int Write = 2; - // static constexpr int Update = Read | Write; - // static constexpr int Action = 3; // Bits containing N, R, W, U - // static constexpr int IfTrue = 4; - // static constexpr int IfFalse = 8; - // static constexpr int Condition = 12; // Bits containing IfTrue, IfFalse - // }; - // struct Op { - // int what; - // TypedAssignment assign; - // }; - // TypedExpr atom, cond; - // Op op0, op1; - // }; - - parser::OpenMPAtomicConstruct::Analysis an; - SetExpr(an.atom, atom); - SetExpr(an.cond, cond); - an.op0 = std::move(op0); - an.op1 = std::move(op1); - return an; -} + operator parser::OpenMPAtomicConstruct::Analysis() const { + // Defined in flang/include/flang/Parser/parse-tree.h + // + // struct Analysis { + // struct Kind { + // static constexpr int None = 0; + // static constexpr int Read = 1; + // static constexpr int Write = 2; + // static constexpr int Update = Read | Write; + // static constexpr int Action = 3; // Bits containing None, Read, + // // Write, Update + // static constexpr int IfTrue = 4; + // static constexpr int IfFalse = 8; + // static constexpr int Condition = 12; // Bits containing IfTrue, + // // IfFalse + // }; + // struct Op { + // int what; + // TypedAssignment assign; + // }; + // TypedExpr atom, cond; + // Op op0, op1; + // }; + + parser::OpenMPAtomicConstruct::Analysis an; + SetExpr(an.atom, atom_); + SetExpr(an.cond, cond_); + an.op0 = std::move(op0_); + an.op1 = std::move(op1_); + return an; + } + +private: + struct Op { + operator parser::OpenMPAtomicConstruct::Analysis::Op() const { + parser::OpenMPAtomicConstruct::Analysis::Op op; + op.what = what; + SetAssignment(op.assign, assign); + return op; + } + + int what; + std::optional assign; + }; + + AtomicAnalysis &addOp(Op &op, int what, + const std::optional &maybeAssign) { + op.what = what; + op.assign = maybeAssign; + return *this; + } + + const SomeExpr &atom_; + const MaybeExpr &cond_; + Op op0_, op1_; +}; +} // namespace /// Check if `expr` satisfies the following conditions for x and v: /// @@ -805,9 +835,9 @@ void OmpStructureChecker::CheckAtomicUpdateOnly( CheckAtomicUpdateAssignment(*maybeUpdate, action.source); using Analysis = parser::OpenMPAtomicConstruct::Analysis; - x.analysis = MakeAtomicAnalysis(atom, std::nullopt, - MakeAtomicAnalysisOp(Analysis::Update, maybeUpdate), - MakeAtomicAnalysisOp(Analysis::None)); + x.analysis = AtomicAnalysis(atom) + .addOp0(Analysis::Update, maybeUpdate) + .addOp1(Analysis::None); } else if (!IsAssignment(action.stmt)) { context_.Say( source, "ATOMIC UPDATE operation should be an assignment"_err_en_US); @@ -889,9 +919,11 @@ void OmpStructureChecker::CheckAtomicConditionalUpdate( } using Analysis = parser::OpenMPAtomicConstruct::Analysis; - x.analysis = MakeAtomicAnalysis(assign.lhs, update.cond, - MakeAtomicAnalysisOp(Analysis::Update | Analysis::IfTrue, assign), - MakeAtomicAnalysisOp(Analysis::None)); + const SomeExpr &atom{assign.lhs}; + + x.analysis = AtomicAnalysis(atom, update.cond) + .addOp0(Analysis::Update | Analysis::IfTrue, assign) + .addOp1(Analysis::None); } void OmpStructureChecker::CheckAtomicUpdateCapture( @@ -936,13 +968,13 @@ void OmpStructureChecker::CheckAtomicUpdateCapture( } if (GetActionStmt(&body.front()).stmt == uact.stmt) { - x.analysis = MakeAtomicAnalysis(atom, std::nullopt, - MakeAtomicAnalysisOp(action, update), - MakeAtomicAnalysisOp(Analysis::Read, capture)); + x.analysis = AtomicAnalysis(atom) + .addOp0(action, update) + .addOp1(Analysis::Read, capture); } else { - x.analysis = MakeAtomicAnalysis(atom, std::nullopt, - MakeAtomicAnalysisOp(Analysis::Read, capture), - MakeAtomicAnalysisOp(action, update)); + x.analysis = AtomicAnalysis(atom) + .addOp0(Analysis::Read, capture) + .addOp1(action, update); } } @@ -1087,15 +1119,16 @@ void OmpStructureChecker::CheckAtomicConditionalUpdateCapture( evaluate::Assignment updAssign{*GetEvaluateAssignment(update.ift.stmt)}; evaluate::Assignment capAssign{*GetEvaluateAssignment(capture.stmt)}; + const SomeExpr &atom{updAssign.lhs}; if (captureFirst) { - x.analysis = MakeAtomicAnalysis(updAssign.lhs, update.cond, - MakeAtomicAnalysisOp(Analysis::Read | captureWhen, capAssign), - MakeAtomicAnalysisOp(Analysis::Write | updateWhen, updAssign)); + x.analysis = AtomicAnalysis(atom, update.cond) + .addOp0(Analysis::Read | captureWhen, capAssign) + .addOp1(Analysis::Write | updateWhen, updAssign); } else { - x.analysis = MakeAtomicAnalysis(updAssign.lhs, update.cond, - MakeAtomicAnalysisOp(Analysis::Write | updateWhen, updAssign), - MakeAtomicAnalysisOp(Analysis::Read | captureWhen, capAssign)); + x.analysis = AtomicAnalysis(atom, update.cond) + .addOp0(Analysis::Write | updateWhen, updAssign) + .addOp1(Analysis::Read | captureWhen, capAssign); } } @@ -1125,9 +1158,9 @@ void OmpStructureChecker::CheckAtomicRead( if (auto maybe{GetConvertInput(maybeRead->rhs)}) { const SomeExpr &atom{*maybe}; using Analysis = parser::OpenMPAtomicConstruct::Analysis; - x.analysis = MakeAtomicAnalysis(atom, std::nullopt, - MakeAtomicAnalysisOp(Analysis::Read, maybeRead), - MakeAtomicAnalysisOp(Analysis::None)); + x.analysis = AtomicAnalysis(atom) + .addOp0(Analysis::Read, maybeRead) + .addOp1(Analysis::None); } } else if (!IsAssignment(action.stmt)) { context_.Say( @@ -1159,9 +1192,9 @@ void OmpStructureChecker::CheckAtomicWrite( CheckAtomicWriteAssignment(*maybeWrite, action.source); using Analysis = parser::OpenMPAtomicConstruct::Analysis; - x.analysis = MakeAtomicAnalysis(atom, std::nullopt, - MakeAtomicAnalysisOp(Analysis::Write, maybeWrite), - MakeAtomicAnalysisOp(Analysis::None)); + x.analysis = AtomicAnalysis(atom) + .addOp0(Analysis::Write, maybeWrite) + .addOp1(Analysis::None); } else if (!IsAssignment(action.stmt)) { context_.Say( x.source, "ATOMIC WRITE operation should be an assignment"_err_en_US); From 2120125ad74caa41402f782a93a5a153520fdc1c Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 29 Jul 2025 16:17:29 -0500 Subject: [PATCH 2/4] [flang][Evaluate] Implement rewriting framework for evaluate::Expr The structure of evaluate::Expr is highly customized for the specific operation or entity that it represents. The different cases are expressed with different types, which makes the traversal and modifications somewhat complicated. There exists a framework for read-only traversal (traverse.h), but there is nothing that helps with modifying evaluate::Expr. It's rare that evaluate::Expr needs to be modified, but for the cases where it needs to be, this code will make it easier. --- flang/include/flang/Evaluate/rewrite.h | 160 +++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 flang/include/flang/Evaluate/rewrite.h diff --git a/flang/include/flang/Evaluate/rewrite.h b/flang/include/flang/Evaluate/rewrite.h new file mode 100644 index 0000000000000..034b1efa21977 --- /dev/null +++ b/flang/include/flang/Evaluate/rewrite.h @@ -0,0 +1,160 @@ +//===-- include/flang/Evaluate/rewrite.h ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef FORTRAN_EVALUATE_REWRITE_H_ +#define FORTRAN_EVALUATE_REWRITE_H_ + +#include "flang/Common/visit.h" +#include "flang/Evaluate/expression.h" +#include "flang/Support/Fortran.h" +#include "llvm/ADT/STLExtras.h" + +#include +#include +#include +#include + +namespace Fortran::evaluate { +namespace rewrite { +namespace detail { +template // +struct IsOperation { + static constexpr bool value{false}; +}; + +template +struct IsOperation> { + static constexpr bool value{true}; +}; +} // namespace detail + +template +constexpr bool is_operation_v{detail::IsOperation::value}; + +/// Individual Expr rewriter that simply constructs an expression that is +/// identical to the input. This is a suitable base class for all user-defined +/// rewriters. +struct Identity { + template + Expr operator()(Expr &&x, const U &op) { + return std::move(x); + } +}; + +/// Bottom-up Expr rewriter. +/// +/// The Mutator traverses and reconstructs given Expr. Going bottom-up, +/// whenever the traversal visits a sub-node of type Expr (for some U), +/// it will invoke the user-provided rewriter via the () operator. +/// +/// If x is of type Expr, it will call (in pseudo-code): +/// rewriter_(x, active_member_of(x.u)) +/// The second parameter is there to make it easier to overload the () operator +/// for specific operations in Expr<...>. +/// +/// The user rewriter is only invoked for Expr, not for Operation, nor any +/// other subobject. +template struct Mutator { + Mutator(Rewriter &rewriter) : rewriter_(rewriter) {} + + template > + U operator()(T &&x) { + if constexpr (std::is_lvalue_reference_v) { + return Mutate(U(x)); + } else { + return Mutate(std::move(x)); + } + } + +private: + template struct LambdaWithRvalueCapture { + LambdaWithRvalueCapture(Rewriter &r, Expr &&c) + : rewriter_(r), capture_(std::move(c)) {} + template Expr operator()(const S &s) { + return rewriter_(std::move(capture_), s); + } + + private: + Rewriter &rewriter_; + Expr &&capture_; + }; + + template >> + T Mutate(T &&x) const { + return std::move(x); + } + + template >> + D Mutate(D &&op, std::make_index_sequence t = {}) const { + return MutateOp(std::move(op), t); + } + + template // + Expr Mutate(Expr &&x) const { + // First construct the new expression with the rewritten op. + Expr n{common::visit( + [&](auto &&s) { // + return Expr(Mutate(std::move(s))); + }, + std::move(x.u))}; + // Return the rewritten expression. The second visit it to make sure + // that the second argument in the call to the rewriter is a part of + // the Expr passed to it. + return common::visit( + LambdaWithRvalueCapture(rewriter_, std::move(n)), std::move(n.u)); + } + + template + std::variant Mutate(std::variant &&u) const { + return common::visit( + [this](auto &&s) { return Mutate(std::move(s)); }, std::move(u)); + } + + template + std::tuple Mutate(std::tuple &&t) const { + return MutateTuple(std::move(t), std::index_sequence_for{}); + } + + template + std::tuple MutateTuple( + std::tuple &&t, std::index_sequence) const { + return std::make_tuple(Mutate(std::move(std::get(t))...)); + } + + template + D MutateOp(D &&op, std::index_sequence) const { + return D(Mutate(std::move(op.template operand()))...); + } + + template + Extremum MutateOp(Extremum &&op, std::index_sequence) const { + return Extremum( + op.ordering, Mutate(std::move(op.template operand()))...); + } + + template + ComplexComponent MutateOp( + ComplexComponent &&op, std::index_sequence) const { + return ComplexComponent( + op.isImaginaryPart, Mutate(std::move(op.template operand()))...); + } + + template + LogicalOperation MutateOp( + LogicalOperation &&op, std::index_sequence) const { + return LogicalOperation( + op.logicalOperator, Mutate(std::move(op.template operand()))...); + } + + Rewriter &rewriter_; +}; + +template Mutator(Rewriter &) -> Mutator; +} // namespace rewrite +} // namespace Fortran::evaluate + +#endif // FORTRAN_EVALUATE_REWRITE_H_ From 84a340b92ebdc2f891367a7d9204e85b13942020 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 30 Jul 2025 11:55:50 -0500 Subject: [PATCH 3/4] [flang][OpenMP] Move rewriting of min/max from Lower to Semantics There semantic analysis of the ATOMIC construct will require additional rewriting (reassociation of certain expressions for user convenience), and that will be driven by diagnoses made in the semantic checks. While the rewriting of min/max is not required to be done in semantic analysis, moving it there will make all rewriting for ATOMIC construct be located in a single location. --- flang/include/flang/Semantics/openmp-utils.h | 8 + flang/lib/Lower/OpenMP/Atomic.cpp | 271 ------------------- flang/lib/Semantics/check-omp-atomic.cpp | 127 ++++++++- 3 files changed, 134 insertions(+), 272 deletions(-) diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h index b8ad9ed17c720..1c54124a5738a 100644 --- a/flang/include/flang/Semantics/openmp-utils.h +++ b/flang/include/flang/Semantics/openmp-utils.h @@ -22,6 +22,8 @@ #include #include +#include +#include namespace Fortran::semantics { class SemanticsContext; @@ -29,6 +31,12 @@ class Symbol; // Add this namespace to avoid potential conflicts namespace omp { +template > U AsRvalue(T &t) { + return U(t); +} + +template T &&AsRvalue(T &&t) { return std::move(t); } + // There is no consistent way to get the source of an ActionStmt, but there // is "source" in Statement. This structure keeps the ActionStmt with the // extracted source for further use. diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp index ed0bff04ed889..ff82a36951bfa 100644 --- a/flang/lib/Lower/OpenMP/Atomic.cpp +++ b/flang/lib/Lower/OpenMP/Atomic.cpp @@ -43,179 +43,6 @@ namespace omp { using namespace Fortran::lower::omp; } -namespace { -// An example of a type that can be used to get the return value from -// the visitor: -// visitor(type_identity) -> result_type -using SomeArgType = evaluate::Type; - -struct GetProc - : public evaluate::Traverse { - using Result = const evaluate::ProcedureDesignator *; - using Base = evaluate::Traverse; - GetProc() : Base(*this) {} - - using Base::operator(); - - static Result Default() { return nullptr; } - - Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; } - static Result Combine(Result a, Result b) { return a != nullptr ? a : b; } -}; - -struct WithType { - WithType(const evaluate::DynamicType &t) : type(t) { - assert(type.category() != common::TypeCategory::Derived && - "Type cannot be a derived type"); - } - - template // - auto visit(VisitorTy &&visitor) const - -> std::invoke_result_t { - switch (type.category()) { - case common::TypeCategory::Integer: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity>{}); - case 2: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - case 8: - return visitor(llvm::type_identity>{}); - case 16: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Unsigned: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity>{}); - case 2: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - case 8: - return visitor(llvm::type_identity>{}); - case 16: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Real: - switch (type.kind()) { - case 2: - return visitor(llvm::type_identity>{}); - case 3: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - case 8: - return visitor(llvm::type_identity>{}); - case 10: - return visitor(llvm::type_identity>{}); - case 16: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Complex: - switch (type.kind()) { - case 2: - return visitor(llvm::type_identity>{}); - case 3: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - case 8: - return visitor(llvm::type_identity>{}); - case 10: - return visitor(llvm::type_identity>{}); - case 16: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Logical: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity>{}); - case 2: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - case 8: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Character: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity>{}); - case 2: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Derived: - (void)Derived; - break; - } - llvm_unreachable("Unhandled type"); - } - - const evaluate::DynamicType &type; - -private: - // Shorter names. - static constexpr auto Character = common::TypeCategory::Character; - static constexpr auto Complex = common::TypeCategory::Complex; - static constexpr auto Derived = common::TypeCategory::Derived; - static constexpr auto Integer = common::TypeCategory::Integer; - static constexpr auto Logical = common::TypeCategory::Logical; - static constexpr auto Real = common::TypeCategory::Real; - static constexpr auto Unsigned = common::TypeCategory::Unsigned; -}; - -template > -U AsRvalue(T &t) { - U copy{t}; - return std::move(copy); -} - -template -T &&AsRvalue(T &&t) { - return std::move(t); -} - -struct ArgumentReplacer - : public evaluate::Traverse { - using Base = evaluate::Traverse; - using Result = bool; - - Result Default() const { return false; } - - ArgumentReplacer(evaluate::ActualArguments &&newArgs) - : Base(*this), args_(std::move(newArgs)) {} - - using Base::operator(); - - template - Result operator()(const evaluate::FunctionRef &x) { - assert(!done_); - auto &mut = const_cast &>(x); - mut.arguments() = args_; - done_ = true; - return true; - } - - Result Combine(Result &&a, Result &&b) { return a || b; } - -private: - bool done_{false}; - evaluate::ActualArguments &&args_; -}; -} // namespace - [[maybe_unused]] static void dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) { auto whatStr = [](int k) { @@ -412,85 +239,6 @@ makeMemOrderAttr(lower::AbstractConverter &converter, return nullptr; } -static bool replaceArgs(semantics::SomeExpr &expr, - evaluate::ActualArguments &&newArgs) { - return ArgumentReplacer(std::move(newArgs))(expr); -} - -static semantics::SomeExpr makeCall(const evaluate::DynamicType &type, - const evaluate::ProcedureDesignator &proc, - const evaluate::ActualArguments &args) { - return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr { - using Type = typename llvm::remove_cvref_t::type; - return evaluate::AsGenericExpr( - evaluate::FunctionRef(AsRvalue(proc), AsRvalue(args))); - }); -} - -static const evaluate::ProcedureDesignator & -getProcedureDesignator(const semantics::SomeExpr &call) { - const evaluate::ProcedureDesignator *proc = GetProc{}(call); - assert(proc && "Call has no procedure designator"); - return *proc; -} - -static semantics::SomeExpr // -genReducedMinMax(const semantics::SomeExpr &orig, - const semantics::SomeExpr *atomArg, - const std::vector &args) { - // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...] - // One of the a_i's, say a_t, must be atomArg. - // Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate - // call = min/max(a_t, tmp). - // Return "call". - - // The min/max intrinsics have 2 mandatory arguments, the rest is optional. - // Make sure that the "tmp = min/max(...)" doesn't promote an optional - // argument to a non-optional position. This could happen if a_t is at - // position 0 or 1. - if (args.size() <= 2) - return orig; - - evaluate::ActualArguments nonAtoms; - - auto AsActual = [](const semantics::SomeExpr &x) { - semantics::SomeExpr copy = x; - return evaluate::ActualArgument(std::move(copy)); - }; - // Semantic checks guarantee that the "atom" shows exactly once in the - // argument list (with potential conversions around it). - // For the first two (non-optional) arguments, if "atom" is among them, - // replace it with another occurrence of the other non-optional argument. - if (atomArg == &args[0]) { - // (atom, x, y...) -> (x, x, y...) - nonAtoms.push_back(AsActual(args[1])); - nonAtoms.push_back(AsActual(args[1])); - } else if (atomArg == &args[1]) { - // (x, atom, y...) -> (x, x, y...) - nonAtoms.push_back(AsActual(args[0])); - nonAtoms.push_back(AsActual(args[0])); - } else { - // (x, y, z...) -> unchanged - nonAtoms.push_back(AsActual(args[0])); - nonAtoms.push_back(AsActual(args[1])); - } - - // The rest of arguments are optional, so we can just skip "atom". - for (size_t i = 2, e = args.size(); i != e; ++i) { - if (atomArg != &args[i]) - nonAtoms.push_back(AsActual(args[i])); - } - - // The type of the intermediate min/max is the same as the type of its - // arguments, which may be different from the type of the original - // expression. The original expression may have additional coverts. - auto tmp = - makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms); - semantics::SomeExpr call = orig; - replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)}); - return call; -} - static mlir::Operation * // genAtomicRead(lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, mlir::Location loc, @@ -610,25 +358,6 @@ genAtomicUpdate(lower::AbstractConverter &converter, auto [opcode, args] = evaluate::GetTopLevelOperationIgnoreResizing(input); assert(!args.empty() && "Update operation without arguments"); - // Pass args as an argument to avoid capturing a structured binding. - const semantics::SomeExpr *atomArg = [&](auto &args) { - for (const semantics::SomeExpr &e : args) { - if (evaluate::IsSameOrConvertOf(e, atom)) - return &e; - } - llvm_unreachable("Atomic variable not in argument list"); - }(args); - - if (opcode == evaluate::operation::Operator::Min || - opcode == evaluate::operation::Operator::Max) { - // Min and max operations are expanded inline, so reduce them to - // operations with exactly two (non-optional) arguments. - rhs = genReducedMinMax(rhs, atomArg, args); - input = *evaluate::GetConvertInput(rhs); - std::tie(opcode, args) = - evaluate::GetTopLevelOperationIgnoreResizing(input); - atomArg = nullptr; // No longer valid. - } for (auto &arg : args) { if (!evaluate::IsSameOrConvertOf(arg, atom)) { mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc)); diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index a5fe820b1069b..0c0e6158485e9 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -14,6 +14,7 @@ #include "flang/Common/indirection.h" #include "flang/Evaluate/expression.h" +#include "flang/Evaluate/rewrite.h" #include "flang/Evaluate/tools.h" #include "flang/Parser/char-block.h" #include "flang/Parser/parse-tree.h" @@ -42,6 +43,8 @@ using namespace Fortran::semantics::omp; namespace operation = Fortran::evaluate::operation; +static MaybeExpr PostSemaRewrite(const SomeExpr &atom, const SomeExpr &expr); + template static bool operator!=(const evaluate::Expr &e, const evaluate::Expr &f) { return !(e == f); @@ -284,7 +287,15 @@ struct AtomicAnalysis { AtomicAnalysis &addOp(Op &op, int what, const std::optional &maybeAssign) { op.what = what; - op.assign = maybeAssign; + if (maybeAssign) { + if (MaybeExpr rewritten{PostSemaRewrite(atom_, maybeAssign->rhs)}) { + op.assign = evaluate::Assignment( + AsRvalue(maybeAssign->lhs), std::move(*rewritten)); + op.assign->u = std::move(maybeAssign->u); + } else { + op.assign = *maybeAssign; + } + } return *this; } @@ -1293,4 +1304,118 @@ void OmpStructureChecker::Leave(const parser::OpenMPAtomicConstruct &) { dirContext_.pop_back(); } +// Rewrite min/max: +// Min and max intrinsics in Fortran take an arbitrary number of arguments +// (two or more). The first two are mandatory, the rest is optional. That +// means that arguments beyond the first two may be optional dummy argument +// from the caller. In that case, a reference to such an argument will +// cause presence test to be emitted, which cannot go inside of the atomic +// operation. Since the atom operand must be present, rewrite the min/max +// operation in a way that avoid the presence tests in the atomic code. +// For example, in +// subroutine f(atom, x, y, z) +// integer :: atom, x +// integer, optional :: y, z +// !$omp atomic update +// atom = min(atom, x, y, z) +// end +// the min operation will become +// atom = min(atom, min(x, y, z)) +// and in the final code +// // Presence check is fine here. +// tmp = min(x, y, z) +// atomic update { +// // Both operands are mandatory, no presence check needed. +// atom = min(atom, tmp) +// } +struct MinMaxRewriter : public evaluate::rewrite::Identity { + using Id = evaluate::rewrite::Identity; + using Id::operator(); + + MinMaxRewriter(const SomeExpr &atom) : atom_(atom) {} + + static bool IsMinMax(const evaluate::ProcedureDesignator &p) { + if (auto *intrin{p.GetSpecificIntrinsic()}) { + return intrin->name == "min" || intrin->name == "max"; + } + return false; + } + + // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...] + // One of the a_i's, say a_t, must be the atom. + // Generate + // min/max(a_t, min/max(a0, a1, ... [except a_t])) + template + evaluate::Expr operator()( + evaluate::Expr &&x, const evaluate::FunctionRef &f) { + const evaluate::ProcedureDesignator &proc = f.proc(); + if (!IsMinMax(proc) || f.arguments().size() <= 2) { + return Id::operator()(std::move(x), f); + } + + // Collect arguments as SomeExpr's and find out which argument + // corresponds to atom. + const SomeExpr *atomArg{nullptr}; + std::vector args; + for (const std::optional &a : f.arguments()) { + if (!a) { + continue; + } + if (const SomeExpr *e{a->UnwrapExpr()}) { + if (evaluate::IsSameOrConvertOf(*e, atom_)) { + atomArg = e; + } + args.push_back(e); + } + } + if (!atomArg) { + return Id::operator()(std::move(x), f); + } + + evaluate::ActualArguments nonAtoms; + + auto AsActual = [](const SomeExpr &z) { + SomeExpr copy = z; + return evaluate::ActualArgument(std::move(copy)); + }; + // Semantic checks guarantee that the "atom" shows exactly once in the + // argument list (with potential conversions around it). + // For the first two (non-optional) arguments, if "atom" is among them, + // replace it with another occurrence of the other non-optional argument. + if (atomArg == args[0]) { + // (atom, x, y...) -> (x, x, y...) + nonAtoms.push_back(AsActual(*args[1])); + nonAtoms.push_back(AsActual(*args[1])); + } else if (atomArg == args[1]) { + // (x, atom, y...) -> (x, x, y...) + nonAtoms.push_back(AsActual(*args[0])); + nonAtoms.push_back(AsActual(*args[0])); + } else { + // (x, y, z...) -> unchanged + nonAtoms.push_back(AsActual(*args[0])); + nonAtoms.push_back(AsActual(*args[1])); + } + + // The rest of arguments are optional, so we can just skip "atom". + for (size_t i = 2, e = args.size(); i != e; ++i) { + if (atomArg != args[i]) + nonAtoms.push_back(AsActual(*args[i])); + } + + SomeExpr tmp = evaluate::AsGenericExpr( + evaluate::FunctionRef(AsRvalue(proc), AsRvalue(nonAtoms))); + + return evaluate::Expr(evaluate::FunctionRef( + AsRvalue(proc), {AsActual(*atomArg), AsActual(tmp)})); + } + +private: + const SomeExpr &atom_; +}; + +static MaybeExpr PostSemaRewrite(const SomeExpr &atom, const SomeExpr &expr) { + MinMaxRewriter rewriter(atom); + return evaluate::rewrite::Mutator(rewriter)(expr); +} + } // namespace Fortran::semantics From ef35eeae708cb5f4e7782d857a0551a263edbbe6 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 8 Aug 2025 13:31:24 -0500 Subject: [PATCH 4/4] [flang][Evaluate] Pattern matching framework for evaluate::Expr Implement a framework to make it easier to detect if evaluate::Expr has certain structure. --- flang/include/flang/Evaluate/match.h | 211 +++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 flang/include/flang/Evaluate/match.h diff --git a/flang/include/flang/Evaluate/match.h b/flang/include/flang/Evaluate/match.h new file mode 100644 index 0000000000000..79da40f7c1338 --- /dev/null +++ b/flang/include/flang/Evaluate/match.h @@ -0,0 +1,211 @@ +//===-- include/flang/Evaluate/match.h --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef FORTRAN_EVALUATE_MATCH_H_ +#define FORTRAN_EVALUATE_MATCH_H_ + +#include "flang/Common/visit.h" +#include "flang/Evaluate/expression.h" +#include "llvm/ADT/STLExtras.h" + +#include +#include +#include +#include + +namespace Fortran::evaluate { +namespace match { +namespace detail { +template // +struct IsOperation { + static constexpr bool value{false}; +}; + +template +struct IsOperation> { + static constexpr bool value{true}; +}; +} // namespace detail + +template +constexpr bool is_operation_v{detail::IsOperation::value}; + +template +const evaluate::Expr &deparen(const evaluate::Expr &x) { + if (auto *parens{std::get_if>(&x.u)}) { + return deparen(parens->template operand<0>()); + } else { + return x; + } +} + +// Expr matchers (patterns) +// +// Each pattern should implement +// bool match(const U &input) const +// member function that returns `true` when the match was successful, +// and `false` otherwise. +// +// Patterns are intended to be composable, i.e. a pattern can take operands +// which themselves are patterns. This composition is expected to match if +// the root pattern and all its operands match given input. + +/// Matches any input as long as it has the expected type `MatchType`. +/// Additionally, it sets the member `ref` to the matched input. +template struct TypePattern { + using MatchType = llvm::remove_cvref_t; + + template bool match(const U &input) const { + if constexpr (std::is_same_v) { + ref = &input; + return true; + } else { + return false; + } + } + + mutable const MatchType *ref{nullptr}; +}; + +/// Matches one of the patterns provided as template arguments. All of these +/// patterns should have the same number of operands, i.e. they all should +/// try to match input expression with the same number of children, i.e. +/// AnyOfPattern is ok, whereas +/// AnyOfPattern is not. +template struct AnyOfPattern { + static_assert(sizeof...(Patterns) != 0); + +private: + using PatternTuple = std::tuple; + + template + using Pattern = typename std::tuple_element::type; + + template + AnyOfPattern(std::index_sequence, const Ops &...ops) + : patterns(std::make_tuple(Pattern(ops...)...)) {} + + template + bool matchOne(const P &pattern, const U &input) const { + if (pattern.match(input)) { + ref = &pattern; + return true; + } + return false; + } + + template + bool matchImpl(const U &input, std::index_sequence) const { + return (matchOne(std::get(patterns), input) || ...); + } + + PatternTuple patterns; + +public: + using Indexes = std::index_sequence_for; + using MatchTypes = std::tuple; + + template + AnyOfPattern(const Ops &...ops) : AnyOfPattern(Indexes{}, ops...) {} + + template bool match(const U &input) const { + return matchImpl(input, Indexes{}); + } + + mutable std::variant ref{ + std::monostate{}}; +}; + +/// Matches any input of type Expr +/// The indent if this pattern is to be a leaf in multi-operand patterns. +template // +struct ExprPattern : public TypePattern> {}; + +/// Matches evaluate::Expr that contains evaluate::Opreration. +template +struct OperationPattern : public TypePattern { +private: + using Indexes = std::index_sequence_for; + + template + bool matchImpl(const S &op, std::index_sequence) const { + using TypeS = llvm::remove_cvref_t; + if constexpr (is_operation_v) { + if constexpr (TypeS::operands == Indexes::size()) { + return TypePattern::match(op) && + (std::get(operands).match(op.template operand()) && ...); + } + } + return false; + } + + std::tuple operands; + +public: + using MatchType = OpType; + + OperationPattern(const Ops &...ops, llvm::type_identity = {}) + : operands(ops...) {} + + template bool match(const evaluate::Expr &input) const { + return common::visit( + [&](auto &&s) { return matchImpl(s, Indexes{}); }, deparen(input).u); + } + + template bool match(const U &input) const { + // Only match Expr + return false; + } +}; + +template +OperationPattern(const Ops &...ops, llvm::type_identity) + -> OperationPattern; + +// Namespace-level definitions + +template using Expr = ExprPattern; + +template +using Op = OperationPattern; + +template +bool match(const Pattern &pattern, const Input &input) { + return pattern.match(input); +} + +// Specific operation patterns + +// -- Add +template +struct Add : public Op, Op0, Op1> { + using Base = Op, Op0, Op1>; + + Add(const Op0 &op0, const Op1 &op1) : Base(op0, op1) {} +}; + +template +Add add(const Op0 &op0, const Op1 &op1) { + return Add(op0, op1); +} + +// -- Mul +template +struct Mul : public Op, Op0, Op1> { + using Base = Op, Op0, Op1>; + + Mul(const Op0 &op0, const Op1 &op1) : Base(op0, op1) {} +}; + +template +Mul mul(const Op0 &op0, const Op1 &op1) { + return Mul(op0, op1); +} +} // namespace match +} // namespace Fortran::evaluate + +#endif // FORTRAN_EVALUATE_MATCH_H_