Skip to content

Commit e58de06

Browse files
authored
[flang][OpenMP] Reassociate logical ATOMIC update expressions (#156961)
This is a follow-up to PR153488 and PR155840, this time for expressions of logical type. The handling of logical operations in Expr<T> differs slightly from regular arithmetic operations. The difference is that the specific operation (e.g. and, or, etc.) is not a part of the type, but stored as a data member. Both the matching code and the reconstruction code needed to be extended to correctly handle the data member. This fixes #144944
1 parent b85d0c5 commit e58de06

File tree

3 files changed

+252
-30
lines changed

3 files changed

+252
-30
lines changed

flang/include/flang/Evaluate/match.h

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "flang/Common/Fortran-consts.h"
1212
#include "flang/Common/visit.h"
1313
#include "flang/Evaluate/expression.h"
14+
#include "flang/Support/Fortran.h"
1415
#include "llvm/ADT/STLExtras.h"
1516

1617
#include <tuple>
@@ -86,9 +87,12 @@ template <typename T> struct TypePattern {
8687
mutable const MatchType *ref{nullptr};
8788
};
8889

89-
/// Matches one of the patterns provided as template arguments. All of these
90-
/// patterns should have the same number of operands, i.e. they all should
91-
/// try to match input expression with the same number of children, i.e.
90+
/// Matches one of the patterns provided as template arguments.
91+
/// Upon creation of an AnyOfPattern object with some arguments, say args,
92+
/// each of the pattern objects will be created using args as arguments to
93+
/// the constructor. This means that each of the patterns should be
94+
/// constructible from args, in particular all patterns should take the same
95+
/// number of inputs. So, for example,
9296
/// AnyOfPattern<SomeBinaryOp, OtherBinaryOp> is ok, whereas
9397
/// AnyOfPattern<SomeBinaryOp, SomeTernaryOp> is not.
9498
template <typename... Patterns> struct AnyOfPattern {
@@ -178,16 +182,67 @@ struct OperationPattern : public TypePattern<OpType> {
178182
};
179183

180184
template <typename OpType, typename... Ops>
181-
OperationPattern(const Ops &...ops, llvm::type_identity<OpType>)
185+
OperationPattern(const Ops &..., llvm::type_identity<OpType>)
182186
-> OperationPattern<OpType, Ops...>;
183187

188+
// Encode the actual operator in the type, so that the class is constructible
189+
// only from operand patterns. This will make it usable in AnyOfPattern.
190+
template <common::LogicalOperator Operator, typename ValType, typename... Ops>
191+
struct LogicalOperationPattern
192+
: public OperationPattern<LogicalOperation<ValType::kind>, Ops...> {
193+
using Base = OperationPattern<LogicalOperation<ValType::kind>, Ops...>;
194+
static constexpr common::LogicalOperator opCode{Operator};
195+
196+
private:
197+
template <int K> bool matchOp(const LogicalOperation<K> &op) const {
198+
if constexpr (ValType::kind == K) {
199+
return op.logicalOperator == opCode;
200+
}
201+
return false;
202+
}
203+
template <typename U> bool matchOp(const U &) const { return false; }
204+
205+
public:
206+
LogicalOperationPattern(const Ops &...ops, llvm::type_identity<ValType> = {})
207+
: Base(ops...) {}
208+
209+
template <typename T> bool match(const evaluate::Expr<T> &input) const {
210+
// All logical operations (for a given type T) have the same operation
211+
// type (LogicalOperation<T::kind>), so the type-based matching will not
212+
// be able to tell specific operations from one another.
213+
// Check the operation code first, if that matches then use the the
214+
// base class's match.
215+
if (common::visit([&](auto &&s) { return matchOp(s); }, deparen(input).u)) {
216+
return Base::match(input);
217+
} else {
218+
return false;
219+
}
220+
}
221+
222+
template <typename U> bool match(const U &input) const { //
223+
return false;
224+
}
225+
};
226+
227+
// No deduction guide for LogicalOperationPattern, since the "Operator"
228+
// parameter cannot be deduced from the constructor arguments.
229+
184230
// Namespace-level definitions
185231

186232
template <typename T> using Expr = ExprPattern<T>;
187233

188234
template <typename OpType, typename... Ops>
189235
using Op = OperationPattern<OpType, Ops...>;
190236

237+
template <common::LogicalOperator Operator, typename ValType, typename... Ops>
238+
using LogicalOp = LogicalOperationPattern<Operator, ValType, Ops...>;
239+
240+
template <common::LogicalOperator Operator, typename Type, typename Op0,
241+
typename Op1>
242+
LogicalOp<Operator, Type, Op0, Op1> logical(const Op0 &op0, const Op1 &op1) {
243+
return LogicalOp<Operator, Type, Op0, Op1>(op0, op1);
244+
}
245+
191246
template <typename Pattern, typename Input>
192247
bool match(const Pattern &pattern, const Input &input) {
193248
return pattern.match(input);

flang/lib/Semantics/check-omp-atomic.cpp

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ template <common::TypeCategory C, int K>
6161
struct IsIntegral<evaluate::Type<C, K>> {
6262
static constexpr bool value{//
6363
C == common::TypeCategory::Integer ||
64-
C == common::TypeCategory::Unsigned ||
65-
C == common::TypeCategory::Logical};
64+
C == common::TypeCategory::Unsigned};
6665
};
6766

6867
template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
@@ -83,10 +82,25 @@ constexpr bool is_floating_point_v{IsFloatingPoint<T>::value};
8382
template <typename T>
8483
constexpr bool is_numeric_v{is_integral_v<T> || is_floating_point_v<T>};
8584

85+
template <typename...> struct IsLogical {
86+
static constexpr bool value{false};
87+
};
88+
89+
template <common::TypeCategory C, int K>
90+
struct IsLogical<evaluate::Type<C, K>> {
91+
static constexpr bool value{C == common::TypeCategory::Logical};
92+
};
93+
94+
template <typename T> constexpr bool is_logical_v{IsLogical<T>::value};
95+
8696
template <typename T, typename Op0, typename Op1>
8797
using ReassocOpBase = evaluate::match::AnyOfPattern< //
8898
evaluate::match::Add<T, Op0, Op1>, //
89-
evaluate::match::Mul<T, Op0, Op1>>;
99+
evaluate::match::Mul<T, Op0, Op1>, //
100+
evaluate::match::LogicalOp<common::LogicalOperator::And, T, Op0, Op1>,
101+
evaluate::match::LogicalOp<common::LogicalOperator::Or, T, Op0, Op1>,
102+
evaluate::match::LogicalOp<common::LogicalOperator::Eqv, T, Op0, Op1>,
103+
evaluate::match::LogicalOp<common::LogicalOperator::Neqv, T, Op0, Op1>>;
90104

91105
template <typename T, typename Op0, typename Op1>
92106
struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
@@ -110,16 +124,16 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
110124
// Try to find cases where the input expression is of the form
111125
// (1) (a . b) . c, or
112126
// (2) a . (b . c),
113-
// where . denotes an associative operation (currently + or *), and a, b, c
114-
// are some subexpresions.
127+
// where . denotes an associative operation, and a, b, c are some
128+
// subexpresions.
115129
// If one of the operands in the nested operation is the atomic variable
116130
// (with some possible type conversions applied to it), bring it to the
117131
// top-level operation, and move the top-level operand into the nested
118132
// operation.
119133
// For example, assuming x is the atomic variable:
120134
// (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b.
121135
template <typename T, typename U,
122-
typename = std::enable_if_t<is_numeric_v<T>>>
136+
typename = std::enable_if_t<is_numeric_v<T> || is_logical_v<T>>>
123137
evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
124138
if constexpr (is_floating_point_v<T>) {
125139
if (!context_.langOptions().AssociativeMath) {
@@ -133,8 +147,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
133147
// some order) from the example above.
134148
evaluate::match::Expr<T> sub[3];
135149
auto inner{reassocOp<T>(sub[0], sub[1])};
136-
auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
137-
auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
150+
auto outer1{reassocOp<T>(inner, sub[2])}; // inner . something
151+
auto outer2{reassocOp<T>(sub[2], inner)}; // something . inner
138152
#if !defined(__clang__) && !defined(_MSC_VER) && \
139153
(__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
140154
// If GCC version < 8.5, use this definition. For the other definition
@@ -167,37 +181,53 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
167181
}
168182
return common::visit(
169183
[&](auto &&s) {
170-
using Expr = evaluate::Expr<T>;
171-
using TypeS = llvm::remove_cvref_t<decltype(s)>;
172-
// This visitor has to be semantically correct for all possible
173-
// types of s even though at runtime s will only be one of the
174-
// matched types.
175-
// Limit the construction to the operation types that we tried
176-
// to match (otherwise TypeS(op1, op2) would fail for non-binary
177-
// operations).
178-
if constexpr (common::HasMember<TypeS, MatchTypes>) {
179-
Expr atom{*sub[atomIdx].ref};
180-
Expr op1{*sub[(atomIdx + 1) % 3].ref};
181-
Expr op2{*sub[(atomIdx + 2) % 3].ref};
182-
return Expr(
183-
TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
184-
} else {
185-
return Expr(TypeS(s));
186-
}
184+
// Build the new expression from the matched components.
185+
return Reconstruct<T, MatchTypes>(s, *sub[atomIdx].ref,
186+
*sub[(atomIdx + 1) % 3].ref, *sub[(atomIdx + 2) % 3].ref);
187187
},
188188
evaluate::match::deparen(x).u);
189189
}
190190
return Id::operator()(std::move(x), u);
191191
}
192192

193193
template <typename T, typename U,
194-
typename = std::enable_if_t<!is_numeric_v<T>>>
194+
typename = std::enable_if_t<!is_numeric_v<T> && !is_logical_v<T>>>
195195
evaluate::Expr<T> operator()(
196196
evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
197197
return Id::operator()(std::move(x), u);
198198
}
199199

200200
private:
201+
template <typename T, typename MatchTypes, typename S>
202+
evaluate::Expr<T> Reconstruct(const S &op, evaluate::Expr<T> atom,
203+
evaluate::Expr<T> op1, evaluate::Expr<T> op2) {
204+
using TypeS = llvm::remove_cvref_t<decltype(op)>;
205+
// This function has to be semantically correct for all possible types
206+
// of S even though at runtime s will only be one of the matched types.
207+
// Limit the construction to the operation types that we tried to match
208+
// (otherwise TypeS(op1, op2) would fail for non-binary operations).
209+
if constexpr (!common::HasMember<TypeS, MatchTypes>) {
210+
return evaluate::Expr<T>(TypeS(op));
211+
} else if constexpr (is_logical_v<T>) {
212+
constexpr int K{T::kind};
213+
if constexpr (std::is_same_v<TypeS, evaluate::LogicalOperation<K>>) {
214+
// Logical operators take an extra argument in their constructor,
215+
// so they need their own reconstruction code.
216+
common::LogicalOperator opCode{op.logicalOperator};
217+
return evaluate::Expr<T>(TypeS( //
218+
opCode, std::move(atom),
219+
evaluate::Expr<T>(TypeS( //
220+
opCode, std::move(op1), std::move(op2)))));
221+
}
222+
} else {
223+
// Generic reconstruction.
224+
return evaluate::Expr<T>(TypeS( //
225+
std::move(atom),
226+
evaluate::Expr<T>(TypeS( //
227+
std::move(op1), std::move(op2)))));
228+
}
229+
}
230+
201231
template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
202232
return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
203233
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
2+
3+
subroutine f00(x, y, z)
4+
implicit none
5+
logical :: x, y, z
6+
7+
!$omp atomic update
8+
x = x .and. y .and. z
9+
end
10+
11+
!CHECK-LABEL: func.func @_QPf00
12+
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
13+
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
14+
!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
15+
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
16+
!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
17+
!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
18+
!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
19+
!CHECK: %[[AND_YZ:[0-9]+]] = arith.andi %[[CVT_Y]], %[[CVT_Z]] : i1
20+
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
21+
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
22+
!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
23+
!CHECK: %[[AND_XYZ:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_YZ]] : i1
24+
!CHECK: %[[RET:[0-9]+]] = fir.convert %[[AND_XYZ]] : (i1) -> !fir.logical<4>
25+
!CHECK: omp.yield(%[[RET]] : !fir.logical<4>)
26+
!CHECK: }
27+
28+
29+
subroutine f01(x, y, z)
30+
implicit none
31+
logical :: x, y, z
32+
33+
!$omp atomic update
34+
x = x .or. y .or. z
35+
end
36+
37+
!CHECK-LABEL: func.func @_QPf01
38+
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
39+
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
40+
!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
41+
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
42+
!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
43+
!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
44+
!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
45+
!CHECK: %[[OR_YZ:[0-9]+]] = arith.ori %[[CVT_Y]], %[[CVT_Z]] : i1
46+
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
47+
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
48+
!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
49+
!CHECK: %[[OR_XYZ:[0-9]+]] = arith.ori %[[CVT_X]], %[[OR_YZ]] : i1
50+
!CHECK: %[[RET:[0-9]+]] = fir.convert %[[OR_XYZ]] : (i1) -> !fir.logical<4>
51+
!CHECK: omp.yield(%[[RET]] : !fir.logical<4>)
52+
!CHECK: }
53+
54+
55+
subroutine f02(x, y, z)
56+
implicit none
57+
logical :: x, y, z
58+
59+
!$omp atomic update
60+
x = x .eqv. y .eqv. z
61+
end
62+
63+
!CHECK-LABEL: func.func @_QPf02
64+
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
65+
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
66+
!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
67+
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
68+
!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
69+
!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
70+
!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
71+
!CHECK: %[[EQV_YZ:[0-9]+]] = arith.cmpi eq, %[[CVT_Y]], %[[CVT_Z]] : i1
72+
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
73+
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
74+
!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
75+
!CHECK: %[[EQV_XYZ:[0-9]+]] = arith.cmpi eq, %[[CVT_X]], %[[EQV_YZ]] : i1
76+
!CHECK: %[[RET:[0-9]+]] = fir.convert %[[EQV_XYZ]] : (i1) -> !fir.logical<4>
77+
!CHECK: omp.yield(%[[RET]] : !fir.logical<4>)
78+
!CHECK: }
79+
80+
81+
subroutine f03(x, y, z)
82+
implicit none
83+
logical :: x, y, z
84+
85+
!$omp atomic update
86+
x = x .neqv. y .neqv. z
87+
end
88+
89+
!CHECK-LABEL: func.func @_QPf03
90+
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
91+
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
92+
!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
93+
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
94+
!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
95+
!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
96+
!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
97+
!CHECK: %[[NEQV_YZ:[0-9]+]] = arith.cmpi ne, %[[CVT_Y]], %[[CVT_Z]] : i1
98+
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
99+
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
100+
!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
101+
!CHECK: %[[NEQV_XYZ:[0-9]+]] = arith.cmpi ne, %[[CVT_X]], %[[NEQV_YZ]] : i1
102+
!CHECK: %[[RET:[0-9]+]] = fir.convert %[[NEQV_XYZ]] : (i1) -> !fir.logical<4>
103+
!CHECK: omp.yield(%[[RET]] : !fir.logical<4>)
104+
!CHECK: }
105+
106+
107+
subroutine f04(x, a, b, c)
108+
implicit none
109+
logical(kind=4) :: x
110+
logical(kind=8) :: a, b, c
111+
112+
!$omp atomic update
113+
x = ((b .and. a) .and. x) .and. c
114+
end
115+
116+
!CHECK-LABEL: func.func @_QPf04
117+
!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
118+
!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
119+
!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
120+
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
121+
!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<!fir.logical<8>>
122+
!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<!fir.logical<8>>
123+
!CHECK: %[[CVT_B:[0-9]+]] = fir.convert %[[LOAD_B]] : (!fir.logical<8>) -> i1
124+
!CHECK: %[[CVT_A:[0-9]+]] = fir.convert %[[LOAD_A]] : (!fir.logical<8>) -> i1
125+
!CHECK: %[[AND_BA:[0-9]+]] = arith.andi %[[CVT_B]], %[[CVT_A]] : i1
126+
!CHECK: %[[LOAD_C:[0-9]+]] = fir.load %[[C]]#0 : !fir.ref<!fir.logical<8>>
127+
!CHECK: %[[CVT_C:[0-9]+]] = fir.convert %[[LOAD_C]] : (!fir.logical<8>) -> i1
128+
!CHECK: %[[AND_BAC:[0-9]+]] = arith.andi %[[AND_BA]], %[[CVT_C]] : i1
129+
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
130+
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
131+
!CHECK: %[[CVT8_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> !fir.logical<8>
132+
!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[CVT8_X]] : (!fir.logical<8>) -> i1
133+
!CHECK: %[[AND_XBAC:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_BAC]] : i1
134+
135+
!CHECK: %[[RET:[0-9]+]] = fir.convert %[[AND_XBAC]] : (i1) -> !fir.logical<4>
136+
!CHECK: omp.yield(%[[RET]] : !fir.logical<4>)
137+
!CHECK: }

0 commit comments

Comments
 (0)