Skip to content

Commit 4b7f380

Browse files
authored
[flang][OpenMP] Move rewriting of min/max from Lower to Semantics (#153038)
There semantic analysis of the ATOMIC construct will require additional rewriting (reassociation of certain expressions for user convenience), and that will be driven by diagnoses made in the semantic checks. While the rewriting of min/max is not required to be done in semantic analysis, moving it there will make all rewriting for ATOMIC construct be located in a single location.
1 parent 54f53c9 commit 4b7f380

File tree

3 files changed

+134
-272
lines changed

3 files changed

+134
-272
lines changed

flang/include/flang/Semantics/openmp-utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,21 @@
2222

2323
#include <optional>
2424
#include <string>
25+
#include <type_traits>
26+
#include <utility>
2527

2628
namespace Fortran::semantics {
2729
class SemanticsContext;
2830
class Symbol;
2931

3032
// Add this namespace to avoid potential conflicts
3133
namespace omp {
34+
template <typename T, typename U = std::remove_const_t<T>> U AsRvalue(T &t) {
35+
return U(t);
36+
}
37+
38+
template <typename T> T &&AsRvalue(T &&t) { return std::move(t); }
39+
3240
// There is no consistent way to get the source of an ActionStmt, but there
3341
// is "source" in Statement<T>. This structure keeps the ActionStmt with the
3442
// extracted source for further use.

flang/lib/Lower/OpenMP/Atomic.cpp

Lines changed: 0 additions & 271 deletions
Original file line numberDiff line numberDiff line change
@@ -43,179 +43,6 @@ namespace omp {
4343
using namespace Fortran::lower::omp;
4444
}
4545

46-
namespace {
47-
// An example of a type that can be used to get the return value from
48-
// the visitor:
49-
// visitor(type_identity<Xyz>) -> result_type
50-
using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4>;
51-
52-
struct GetProc
53-
: public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *,
54-
false> {
55-
using Result = const evaluate::ProcedureDesignator *;
56-
using Base = evaluate::Traverse<GetProc, Result, false>;
57-
GetProc() : Base(*this) {}
58-
59-
using Base::operator();
60-
61-
static Result Default() { return nullptr; }
62-
63-
Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; }
64-
static Result Combine(Result a, Result b) { return a != nullptr ? a : b; }
65-
};
66-
67-
struct WithType {
68-
WithType(const evaluate::DynamicType &t) : type(t) {
69-
assert(type.category() != common::TypeCategory::Derived &&
70-
"Type cannot be a derived type");
71-
}
72-
73-
template <typename VisitorTy> //
74-
auto visit(VisitorTy &&visitor) const
75-
-> std::invoke_result_t<VisitorTy, SomeArgType> {
76-
switch (type.category()) {
77-
case common::TypeCategory::Integer:
78-
switch (type.kind()) {
79-
case 1:
80-
return visitor(llvm::type_identity<evaluate::Type<Integer, 1>>{});
81-
case 2:
82-
return visitor(llvm::type_identity<evaluate::Type<Integer, 2>>{});
83-
case 4:
84-
return visitor(llvm::type_identity<evaluate::Type<Integer, 4>>{});
85-
case 8:
86-
return visitor(llvm::type_identity<evaluate::Type<Integer, 8>>{});
87-
case 16:
88-
return visitor(llvm::type_identity<evaluate::Type<Integer, 16>>{});
89-
}
90-
break;
91-
case common::TypeCategory::Unsigned:
92-
switch (type.kind()) {
93-
case 1:
94-
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 1>>{});
95-
case 2:
96-
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 2>>{});
97-
case 4:
98-
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 4>>{});
99-
case 8:
100-
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 8>>{});
101-
case 16:
102-
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 16>>{});
103-
}
104-
break;
105-
case common::TypeCategory::Real:
106-
switch (type.kind()) {
107-
case 2:
108-
return visitor(llvm::type_identity<evaluate::Type<Real, 2>>{});
109-
case 3:
110-
return visitor(llvm::type_identity<evaluate::Type<Real, 3>>{});
111-
case 4:
112-
return visitor(llvm::type_identity<evaluate::Type<Real, 4>>{});
113-
case 8:
114-
return visitor(llvm::type_identity<evaluate::Type<Real, 8>>{});
115-
case 10:
116-
return visitor(llvm::type_identity<evaluate::Type<Real, 10>>{});
117-
case 16:
118-
return visitor(llvm::type_identity<evaluate::Type<Real, 16>>{});
119-
}
120-
break;
121-
case common::TypeCategory::Complex:
122-
switch (type.kind()) {
123-
case 2:
124-
return visitor(llvm::type_identity<evaluate::Type<Complex, 2>>{});
125-
case 3:
126-
return visitor(llvm::type_identity<evaluate::Type<Complex, 3>>{});
127-
case 4:
128-
return visitor(llvm::type_identity<evaluate::Type<Complex, 4>>{});
129-
case 8:
130-
return visitor(llvm::type_identity<evaluate::Type<Complex, 8>>{});
131-
case 10:
132-
return visitor(llvm::type_identity<evaluate::Type<Complex, 10>>{});
133-
case 16:
134-
return visitor(llvm::type_identity<evaluate::Type<Complex, 16>>{});
135-
}
136-
break;
137-
case common::TypeCategory::Logical:
138-
switch (type.kind()) {
139-
case 1:
140-
return visitor(llvm::type_identity<evaluate::Type<Logical, 1>>{});
141-
case 2:
142-
return visitor(llvm::type_identity<evaluate::Type<Logical, 2>>{});
143-
case 4:
144-
return visitor(llvm::type_identity<evaluate::Type<Logical, 4>>{});
145-
case 8:
146-
return visitor(llvm::type_identity<evaluate::Type<Logical, 8>>{});
147-
}
148-
break;
149-
case common::TypeCategory::Character:
150-
switch (type.kind()) {
151-
case 1:
152-
return visitor(llvm::type_identity<evaluate::Type<Character, 1>>{});
153-
case 2:
154-
return visitor(llvm::type_identity<evaluate::Type<Character, 2>>{});
155-
case 4:
156-
return visitor(llvm::type_identity<evaluate::Type<Character, 4>>{});
157-
}
158-
break;
159-
case common::TypeCategory::Derived:
160-
(void)Derived;
161-
break;
162-
}
163-
llvm_unreachable("Unhandled type");
164-
}
165-
166-
const evaluate::DynamicType &type;
167-
168-
private:
169-
// Shorter names.
170-
static constexpr auto Character = common::TypeCategory::Character;
171-
static constexpr auto Complex = common::TypeCategory::Complex;
172-
static constexpr auto Derived = common::TypeCategory::Derived;
173-
static constexpr auto Integer = common::TypeCategory::Integer;
174-
static constexpr auto Logical = common::TypeCategory::Logical;
175-
static constexpr auto Real = common::TypeCategory::Real;
176-
static constexpr auto Unsigned = common::TypeCategory::Unsigned;
177-
};
178-
179-
template <typename T, typename U = std::remove_const_t<T>>
180-
U AsRvalue(T &t) {
181-
U copy{t};
182-
return std::move(copy);
183-
}
184-
185-
template <typename T>
186-
T &&AsRvalue(T &&t) {
187-
return std::move(t);
188-
}
189-
190-
struct ArgumentReplacer
191-
: public evaluate::Traverse<ArgumentReplacer, bool, false> {
192-
using Base = evaluate::Traverse<ArgumentReplacer, bool, false>;
193-
using Result = bool;
194-
195-
Result Default() const { return false; }
196-
197-
ArgumentReplacer(evaluate::ActualArguments &&newArgs)
198-
: Base(*this), args_(std::move(newArgs)) {}
199-
200-
using Base::operator();
201-
202-
template <typename T>
203-
Result operator()(const evaluate::FunctionRef<T> &x) {
204-
assert(!done_);
205-
auto &mut = const_cast<evaluate::FunctionRef<T> &>(x);
206-
mut.arguments() = args_;
207-
done_ = true;
208-
return true;
209-
}
210-
211-
Result Combine(Result &&a, Result &&b) { return a || b; }
212-
213-
private:
214-
bool done_{false};
215-
evaluate::ActualArguments &&args_;
216-
};
217-
} // namespace
218-
21946
[[maybe_unused]] static void
22047
dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) {
22148
auto whatStr = [](int k) {
@@ -412,85 +239,6 @@ makeMemOrderAttr(lower::AbstractConverter &converter,
412239
return nullptr;
413240
}
414241

415-
static bool replaceArgs(semantics::SomeExpr &expr,
416-
evaluate::ActualArguments &&newArgs) {
417-
return ArgumentReplacer(std::move(newArgs))(expr);
418-
}
419-
420-
static semantics::SomeExpr makeCall(const evaluate::DynamicType &type,
421-
const evaluate::ProcedureDesignator &proc,
422-
const evaluate::ActualArguments &args) {
423-
return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr {
424-
using Type = typename llvm::remove_cvref_t<decltype(s)>::type;
425-
return evaluate::AsGenericExpr(
426-
evaluate::FunctionRef<Type>(AsRvalue(proc), AsRvalue(args)));
427-
});
428-
}
429-
430-
static const evaluate::ProcedureDesignator &
431-
getProcedureDesignator(const semantics::SomeExpr &call) {
432-
const evaluate::ProcedureDesignator *proc = GetProc{}(call);
433-
assert(proc && "Call has no procedure designator");
434-
return *proc;
435-
}
436-
437-
static semantics::SomeExpr //
438-
genReducedMinMax(const semantics::SomeExpr &orig,
439-
const semantics::SomeExpr *atomArg,
440-
const std::vector<semantics::SomeExpr> &args) {
441-
// Take a list of arguments to a min/max operation, e.g. [a0, a1, ...]
442-
// One of the a_i's, say a_t, must be atomArg.
443-
// Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate
444-
// call = min/max(a_t, tmp).
445-
// Return "call".
446-
447-
// The min/max intrinsics have 2 mandatory arguments, the rest is optional.
448-
// Make sure that the "tmp = min/max(...)" doesn't promote an optional
449-
// argument to a non-optional position. This could happen if a_t is at
450-
// position 0 or 1.
451-
if (args.size() <= 2)
452-
return orig;
453-
454-
evaluate::ActualArguments nonAtoms;
455-
456-
auto AsActual = [](const semantics::SomeExpr &x) {
457-
semantics::SomeExpr copy = x;
458-
return evaluate::ActualArgument(std::move(copy));
459-
};
460-
// Semantic checks guarantee that the "atom" shows exactly once in the
461-
// argument list (with potential conversions around it).
462-
// For the first two (non-optional) arguments, if "atom" is among them,
463-
// replace it with another occurrence of the other non-optional argument.
464-
if (atomArg == &args[0]) {
465-
// (atom, x, y...) -> (x, x, y...)
466-
nonAtoms.push_back(AsActual(args[1]));
467-
nonAtoms.push_back(AsActual(args[1]));
468-
} else if (atomArg == &args[1]) {
469-
// (x, atom, y...) -> (x, x, y...)
470-
nonAtoms.push_back(AsActual(args[0]));
471-
nonAtoms.push_back(AsActual(args[0]));
472-
} else {
473-
// (x, y, z...) -> unchanged
474-
nonAtoms.push_back(AsActual(args[0]));
475-
nonAtoms.push_back(AsActual(args[1]));
476-
}
477-
478-
// The rest of arguments are optional, so we can just skip "atom".
479-
for (size_t i = 2, e = args.size(); i != e; ++i) {
480-
if (atomArg != &args[i])
481-
nonAtoms.push_back(AsActual(args[i]));
482-
}
483-
484-
// The type of the intermediate min/max is the same as the type of its
485-
// arguments, which may be different from the type of the original
486-
// expression. The original expression may have additional coverts.
487-
auto tmp =
488-
makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms);
489-
semantics::SomeExpr call = orig;
490-
replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)});
491-
return call;
492-
}
493-
494242
static mlir::Operation * //
495243
genAtomicRead(lower::AbstractConverter &converter,
496244
semantics::SemanticsContext &semaCtx, mlir::Location loc,
@@ -610,25 +358,6 @@ genAtomicUpdate(lower::AbstractConverter &converter,
610358
auto [opcode, args] = evaluate::GetTopLevelOperationIgnoreResizing(input);
611359
assert(!args.empty() && "Update operation without arguments");
612360

613-
// Pass args as an argument to avoid capturing a structured binding.
614-
const semantics::SomeExpr *atomArg = [&](auto &args) {
615-
for (const semantics::SomeExpr &e : args) {
616-
if (evaluate::IsSameOrConvertOf(e, atom))
617-
return &e;
618-
}
619-
llvm_unreachable("Atomic variable not in argument list");
620-
}(args);
621-
622-
if (opcode == evaluate::operation::Operator::Min ||
623-
opcode == evaluate::operation::Operator::Max) {
624-
// Min and max operations are expanded inline, so reduce them to
625-
// operations with exactly two (non-optional) arguments.
626-
rhs = genReducedMinMax(rhs, atomArg, args);
627-
input = *evaluate::GetConvertInput(rhs);
628-
std::tie(opcode, args) =
629-
evaluate::GetTopLevelOperationIgnoreResizing(input);
630-
atomArg = nullptr; // No longer valid.
631-
}
632361
for (auto &arg : args) {
633362
if (!evaluate::IsSameOrConvertOf(arg, atom)) {
634363
mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc));

0 commit comments

Comments
 (0)