1111#include " flang/Evaluate/expression.h"
1212#include " flang/Evaluate/fold.h"
1313#include " flang/Evaluate/tools.h"
14+ #include " flang/Evaluate/traverse.h"
15+ #include " flang/Evaluate/type.h"
1416#include " flang/Lower/AbstractConverter.h"
1517#include " flang/Lower/PFTBuilder.h"
1618#include " flang/Lower/StatementContext.h"
@@ -41,6 +43,179 @@ namespace omp {
4143using namespace Fortran ::lower::omp;
4244}
4345
46+ namespace {
47+ // An example of a type that can be used to get the return value from
48+ // the visitor:
49+ // visitor(type_identity<Xyz>) -> result_type
50+ using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4 >;
51+
52+ struct GetProc
53+ : public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *,
54+ false > {
55+ using Result = const evaluate::ProcedureDesignator *;
56+ using Base = evaluate::Traverse<GetProc, Result, false >;
57+ GetProc () : Base(*this ) {}
58+
59+ using Base::operator ();
60+
61+ static Result Default () { return nullptr ; }
62+
63+ Result operator ()(const evaluate::ProcedureDesignator &p) const { return &p; }
64+ static Result Combine (Result a, Result b) { return a != nullptr ? a : b; }
65+ };
66+
67+ struct WithType {
68+ WithType (const evaluate::DynamicType &t) : type(t) {
69+ assert (type.category () != common::TypeCategory::Derived &&
70+ " Type cannot be a derived type" );
71+ }
72+
73+ template <typename VisitorTy> //
74+ auto visit (VisitorTy &&visitor) const
75+ -> std::invoke_result_t<VisitorTy, SomeArgType> {
76+ switch (type.category ()) {
77+ case common::TypeCategory::Integer:
78+ switch (type.kind ()) {
79+ case 1 :
80+ return visitor (llvm::type_identity<evaluate::Type<Integer, 1 >>{});
81+ case 2 :
82+ return visitor (llvm::type_identity<evaluate::Type<Integer, 2 >>{});
83+ case 4 :
84+ return visitor (llvm::type_identity<evaluate::Type<Integer, 4 >>{});
85+ case 8 :
86+ return visitor (llvm::type_identity<evaluate::Type<Integer, 8 >>{});
87+ case 16 :
88+ return visitor (llvm::type_identity<evaluate::Type<Integer, 16 >>{});
89+ }
90+ break ;
91+ case common::TypeCategory::Unsigned:
92+ switch (type.kind ()) {
93+ case 1 :
94+ return visitor (llvm::type_identity<evaluate::Type<Unsigned, 1 >>{});
95+ case 2 :
96+ return visitor (llvm::type_identity<evaluate::Type<Unsigned, 2 >>{});
97+ case 4 :
98+ return visitor (llvm::type_identity<evaluate::Type<Unsigned, 4 >>{});
99+ case 8 :
100+ return visitor (llvm::type_identity<evaluate::Type<Unsigned, 8 >>{});
101+ case 16 :
102+ return visitor (llvm::type_identity<evaluate::Type<Unsigned, 16 >>{});
103+ }
104+ break ;
105+ case common::TypeCategory::Real:
106+ switch (type.kind ()) {
107+ case 2 :
108+ return visitor (llvm::type_identity<evaluate::Type<Real, 2 >>{});
109+ case 3 :
110+ return visitor (llvm::type_identity<evaluate::Type<Real, 3 >>{});
111+ case 4 :
112+ return visitor (llvm::type_identity<evaluate::Type<Real, 4 >>{});
113+ case 8 :
114+ return visitor (llvm::type_identity<evaluate::Type<Real, 8 >>{});
115+ case 10 :
116+ return visitor (llvm::type_identity<evaluate::Type<Real, 10 >>{});
117+ case 16 :
118+ return visitor (llvm::type_identity<evaluate::Type<Real, 16 >>{});
119+ }
120+ break ;
121+ case common::TypeCategory::Complex:
122+ switch (type.kind ()) {
123+ case 2 :
124+ return visitor (llvm::type_identity<evaluate::Type<Complex, 2 >>{});
125+ case 3 :
126+ return visitor (llvm::type_identity<evaluate::Type<Complex, 3 >>{});
127+ case 4 :
128+ return visitor (llvm::type_identity<evaluate::Type<Complex, 4 >>{});
129+ case 8 :
130+ return visitor (llvm::type_identity<evaluate::Type<Complex, 8 >>{});
131+ case 10 :
132+ return visitor (llvm::type_identity<evaluate::Type<Complex, 10 >>{});
133+ case 16 :
134+ return visitor (llvm::type_identity<evaluate::Type<Complex, 16 >>{});
135+ }
136+ break ;
137+ case common::TypeCategory::Logical:
138+ switch (type.kind ()) {
139+ case 1 :
140+ return visitor (llvm::type_identity<evaluate::Type<Logical, 1 >>{});
141+ case 2 :
142+ return visitor (llvm::type_identity<evaluate::Type<Logical, 2 >>{});
143+ case 4 :
144+ return visitor (llvm::type_identity<evaluate::Type<Logical, 4 >>{});
145+ case 8 :
146+ return visitor (llvm::type_identity<evaluate::Type<Logical, 8 >>{});
147+ }
148+ break ;
149+ case common::TypeCategory::Character:
150+ switch (type.kind ()) {
151+ case 1 :
152+ return visitor (llvm::type_identity<evaluate::Type<Character, 1 >>{});
153+ case 2 :
154+ return visitor (llvm::type_identity<evaluate::Type<Character, 2 >>{});
155+ case 4 :
156+ return visitor (llvm::type_identity<evaluate::Type<Character, 4 >>{});
157+ }
158+ break ;
159+ case common::TypeCategory::Derived:
160+ (void )Derived;
161+ break ;
162+ }
163+ llvm_unreachable (" Unhandled type" );
164+ }
165+
166+ const evaluate::DynamicType &type;
167+
168+ private:
169+ // Shorter names.
170+ static constexpr auto Character = common::TypeCategory::Character;
171+ static constexpr auto Complex = common::TypeCategory::Complex;
172+ static constexpr auto Derived = common::TypeCategory::Derived;
173+ static constexpr auto Integer = common::TypeCategory::Integer;
174+ static constexpr auto Logical = common::TypeCategory::Logical;
175+ static constexpr auto Real = common::TypeCategory::Real;
176+ static constexpr auto Unsigned = common::TypeCategory::Unsigned;
177+ };
178+
179+ template <typename T, typename U = std::remove_const_t <T>>
180+ U AsRvalue (T &t) {
181+ U copy{t};
182+ return std::move (copy);
183+ }
184+
185+ template <typename T>
186+ T &&AsRvalue(T &&t) {
187+ return std::move (t);
188+ }
189+
190+ struct ArgumentReplacer
191+ : public evaluate::Traverse<ArgumentReplacer, bool , false > {
192+ using Base = evaluate::Traverse<ArgumentReplacer, bool , false >;
193+ using Result = bool ;
194+
195+ Result Default () const { return false ; }
196+
197+ ArgumentReplacer (evaluate::ActualArguments &&newArgs)
198+ : Base(*this ), args_(std::move(newArgs)) {}
199+
200+ using Base::operator ();
201+
202+ template <typename T>
203+ Result operator ()(const evaluate::FunctionRef<T> &x) {
204+ assert (!done_);
205+ auto &mut = const_cast <evaluate::FunctionRef<T> &>(x);
206+ mut.arguments () = args_;
207+ done_ = true ;
208+ return true ;
209+ }
210+
211+ Result Combine (Result &&a, Result &&b) { return a || b; }
212+
213+ private:
214+ bool done_{false };
215+ evaluate::ActualArguments &&args_;
216+ };
217+ } // namespace
218+
44219[[maybe_unused]] static void
45220dumpAtomicAnalysis (const parser::OpenMPAtomicConstruct::Analysis &analysis) {
46221 auto whatStr = [](int k) {
@@ -237,6 +412,85 @@ makeMemOrderAttr(lower::AbstractConverter &converter,
237412 return nullptr ;
238413}
239414
415+ static bool replaceArgs (semantics::SomeExpr &expr,
416+ evaluate::ActualArguments &&newArgs) {
417+ return ArgumentReplacer (std::move (newArgs))(expr);
418+ }
419+
420+ static semantics::SomeExpr makeCall (const evaluate::DynamicType &type,
421+ const evaluate::ProcedureDesignator &proc,
422+ const evaluate::ActualArguments &args) {
423+ return WithType (type).visit ([&](auto &&s) -> semantics::SomeExpr {
424+ using Type = typename llvm::remove_cvref_t <decltype (s)>::type;
425+ return evaluate::AsGenericExpr (
426+ evaluate::FunctionRef<Type>(AsRvalue (proc), AsRvalue (args)));
427+ });
428+ }
429+
430+ static const evaluate::ProcedureDesignator &
431+ getProcedureDesignator (const semantics::SomeExpr &call) {
432+ const evaluate::ProcedureDesignator *proc = GetProc{}(call);
433+ assert (proc && " Call has no procedure designator" );
434+ return *proc;
435+ }
436+
437+ static semantics::SomeExpr //
438+ genReducedMinMax (const semantics::SomeExpr &orig,
439+ const semantics::SomeExpr *atomArg,
440+ const std::vector<semantics::SomeExpr> &args) {
441+ // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...]
442+ // One of the a_i's, say a_t, must be atomArg.
443+ // Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate
444+ // call = min/max(a_t, tmp).
445+ // Return "call".
446+
447+ // The min/max intrinsics have 2 mandatory arguments, the rest is optional.
448+ // Make sure that the "tmp = min/max(...)" doesn't promote an optional
449+ // argument to a non-optional position. This could happen if a_t is at
450+ // position 0 or 1.
451+ if (args.size () <= 2 )
452+ return orig;
453+
454+ evaluate::ActualArguments nonAtoms;
455+
456+ auto AsActual = [](const semantics::SomeExpr &x) {
457+ semantics::SomeExpr copy = x;
458+ return evaluate::ActualArgument (std::move (copy));
459+ };
460+ // Semantic checks guarantee that the "atom" shows exactly once in the
461+ // argument list (with potential conversions around it).
462+ // For the first two (non-optional) arguments, if "atom" is among them,
463+ // replace it with another occurrence of the other non-optional argument.
464+ if (atomArg == &args[0 ]) {
465+ // (atom, x, y...) -> (x, x, y...)
466+ nonAtoms.push_back (AsActual (args[1 ]));
467+ nonAtoms.push_back (AsActual (args[1 ]));
468+ } else if (atomArg == &args[1 ]) {
469+ // (x, atom, y...) -> (x, x, y...)
470+ nonAtoms.push_back (AsActual (args[0 ]));
471+ nonAtoms.push_back (AsActual (args[0 ]));
472+ } else {
473+ // (x, y, z...) -> unchanged
474+ nonAtoms.push_back (AsActual (args[0 ]));
475+ nonAtoms.push_back (AsActual (args[1 ]));
476+ }
477+
478+ // The rest of arguments are optional, so we can just skip "atom".
479+ for (size_t i = 2 , e = args.size (); i != e; ++i) {
480+ if (atomArg != &args[i])
481+ nonAtoms.push_back (AsActual (args[i]));
482+ }
483+
484+ // The type of the intermediate min/max is the same as the type of its
485+ // arguments, which may be different from the type of the original
486+ // expression. The original expression may have additional coverts.
487+ auto tmp =
488+ makeCall (*atomArg->GetType (), getProcedureDesignator (orig), nonAtoms);
489+ semantics::SomeExpr call = orig;
490+ replaceArgs (call, {AsActual (*atomArg), AsActual (tmp)});
491+ return call;
492+ }
493+
240494static mlir::Operation * //
241495genAtomicRead (lower::AbstractConverter &converter,
242496 semantics::SemanticsContext &semaCtx, mlir::Location loc,
@@ -350,10 +604,29 @@ genAtomicUpdate(lower::AbstractConverter &converter,
350604 mlir::Type atomType = fir::unwrapRefType (atomAddr.getType ());
351605
352606 // This must exist by now.
353- semantics::SomeExpr input = * evaluate::GetConvertInput ( assign.rhs ) ;
354- std::vector< semantics::SomeExpr> args =
355- evaluate::GetTopLevelOperation (input). second ;
607+ semantics::SomeExpr rhs = assign.rhs ;
608+ semantics::SomeExpr input = * evaluate::GetConvertInput (rhs);
609+ auto [opcode, args] = evaluate::GetTopLevelOperation (input);
356610 assert (!args.empty () && " Update operation without arguments" );
611+
612+ // Pass args as an argument to avoid capturing a structured binding.
613+ const semantics::SomeExpr *atomArg = [&](auto &args) {
614+ for (const semantics::SomeExpr &e : args) {
615+ if (evaluate::IsSameOrConvertOf (e, atom))
616+ return &e;
617+ }
618+ llvm_unreachable (" Atomic variable not in argument list" );
619+ }(args);
620+
621+ if (opcode == evaluate::operation::Operator::Min ||
622+ opcode == evaluate::operation::Operator::Max) {
623+ // Min and max operations are expanded inline, so reduce them to
624+ // operations with exactly two (non-optional) arguments.
625+ rhs = genReducedMinMax (rhs, atomArg, args);
626+ input = *evaluate::GetConvertInput (rhs);
627+ std::tie (opcode, args) = evaluate::GetTopLevelOperation (input);
628+ atomArg = nullptr ; // No longer valid.
629+ }
357630 for (auto &arg : args) {
358631 if (!evaluate::IsSameOrConvertOf (arg, atom)) {
359632 mlir::Value val = fir::getBase (converter.genExprValue (arg, naCtx, &loc));
@@ -372,7 +645,7 @@ genAtomicUpdate(lower::AbstractConverter &converter,
372645
373646 converter.overrideExprValues (&overrides);
374647 mlir::Value updated =
375- fir::getBase (converter.genExprValue (assign. rhs , stmtCtx, &loc));
648+ fir::getBase (converter.genExprValue (rhs, stmtCtx, &loc));
376649 mlir::Value converted = builder.createConvert (loc, atomType, updated);
377650 builder.create <mlir::omp::YieldOp>(loc, converted);
378651 converter.resetExprOverrides ();
0 commit comments