Skip to content

Commit 21e1ab3

Browse files
authored
[flang][OpenMP] Reassociate floating-point ATOMIC update expressions (#155840)
This is a follow-up to PR153488, this time the reassociation is enabled for floating-point expressions, but only when associative-nath is enabled in the language options. This can be done via -ffast-math on the command line.
1 parent 4f6032f commit 21e1ab3

File tree

3 files changed

+145
-7
lines changed

3 files changed

+145
-7
lines changed

flang/include/flang/Evaluate/match.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#ifndef FORTRAN_EVALUATE_MATCH_H_
99
#define FORTRAN_EVALUATE_MATCH_H_
1010

11+
#include "flang/Common/Fortran-consts.h"
1112
#include "flang/Common/visit.h"
1213
#include "flang/Evaluate/expression.h"
1314
#include "llvm/ADT/STLExtras.h"
@@ -34,15 +35,29 @@ struct IsOperation<T, std::void_t<decltype(T::operands)>> {
3435
template <typename T>
3536
constexpr bool is_operation_v{detail::IsOperation<T>::value};
3637

37-
template <typename T>
38-
const evaluate::Expr<T> &deparen(const evaluate::Expr<T> &x) {
39-
if (auto *parens{std::get_if<evaluate::Parentheses<T>>(&x.u)}) {
38+
template <common::TypeCategory C, int K>
39+
const evaluate::Expr<Type<C, K>> &deparen(const evaluate::Expr<Type<C, K>> &x) {
40+
if (auto *parens{std::get_if<Parentheses<Type<C, K>>>(&x.u)}) {
4041
return deparen(parens->template operand<0>());
4142
} else {
4243
return x;
4344
}
4445
}
4546

47+
template <common::TypeCategory C>
48+
const evaluate::Expr<SomeKind<C>> &deparen(
49+
const evaluate::Expr<SomeKind<C>> &x) {
50+
return x;
51+
}
52+
53+
// Some expressions (e.g. TypelessExpression) don't allow parentheses, while
54+
// those that do have Expr<Type> as the argument to the parentheses. This means
55+
// that there is no consistent return type that works for all expressions.
56+
// Delete this overload explicitly so an attempt to use it creates a clearer
57+
// error message.
58+
const evaluate::Expr<SomeType> &deparen(
59+
const evaluate::Expr<SomeType> &) = delete;
60+
4661
// Expr<T> matchers (patterns)
4762
//
4863
// Each pattern should implement

flang/lib/Semantics/check-omp-atomic.cpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,22 @@ struct IsIntegral<evaluate::Type<C, K>> {
6767

6868
template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
6969

70+
template <typename...> struct IsFloatingPoint {
71+
static constexpr bool value{false};
72+
};
73+
74+
template <common::TypeCategory C, int K>
75+
struct IsFloatingPoint<evaluate::Type<C, K>> {
76+
static constexpr bool value{//
77+
C == common::TypeCategory::Real || C == common::TypeCategory::Complex};
78+
};
79+
80+
template <typename T>
81+
constexpr bool is_floating_point_v{IsFloatingPoint<T>::value};
82+
83+
template <typename T>
84+
constexpr bool is_numeric_v{is_integral_v<T> || is_floating_point_v<T>};
85+
7086
template <typename T, typename Op0, typename Op1>
7187
using ReassocOpBase = evaluate::match::AnyOfPattern< //
7288
evaluate::match::Add<T, Op0, Op1>, //
@@ -88,7 +104,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
88104
using Id = evaluate::rewrite::Identity;
89105
struct NonIntegralTag {};
90106

91-
ReassocRewriter(const SomeExpr &atom) : atom_(atom) {}
107+
ReassocRewriter(const SomeExpr &atom, const SemanticsContext &context)
108+
: atom_(atom), context_(context) {}
92109

93110
// Try to find cases where the input expression is of the form
94111
// (1) (a . b) . c, or
@@ -102,8 +119,13 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
102119
// For example, assuming x is the atomic variable:
103120
// (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b.
104121
template <typename T, typename U,
105-
typename = std::enable_if_t<is_integral_v<T>>>
122+
typename = std::enable_if_t<is_numeric_v<T>>>
106123
evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
124+
if constexpr (is_floating_point_v<T>) {
125+
if (!context_.langOptions().AssociativeMath) {
126+
return Id::operator()(std::move(x), u);
127+
}
128+
}
107129
// As per the above comment, there are 3 subexpressions involved in this
108130
// transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
109131
// same as U, plus it will store a pointer (ref) to the matched expression.
@@ -169,7 +191,7 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
169191
}
170192

171193
template <typename T, typename U,
172-
typename = std::enable_if_t<!is_integral_v<T>>>
194+
typename = std::enable_if_t<!is_numeric_v<T>>>
173195
evaluate::Expr<T> operator()(
174196
evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
175197
return Id::operator()(std::move(x), u);
@@ -181,6 +203,7 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
181203
}
182204

183205
const SomeExpr &atom_;
206+
const SemanticsContext &context_;
184207
};
185208

186209
struct AnalyzedCondStmt {
@@ -809,7 +832,7 @@ OmpStructureChecker::CheckAtomicUpdateAssignment(
809832
CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source);
810833
return std::nullopt;
811834
} else if (tryReassoc) {
812-
ReassocRewriter ra(atom);
835+
ReassocRewriter ra(atom, context_);
813836
SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)};
814837

815838
std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs(
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
!RUN: %flang_fc1 -emit-hlfir -ffast-math -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
2+
3+
subroutine f00(x, y)
4+
implicit none
5+
real :: x, y
6+
7+
!$omp atomic update
8+
x = ((x + 1) + y) + 2
9+
end
10+
11+
!CHECK-LABEL: func.func @_QPf00
12+
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
13+
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
14+
!CHECK: %cst = arith.constant 1.000000e+00 : f32
15+
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<f32>
16+
!CHECK: %[[Y_1:[0-9]+]] = arith.addf %cst, %[[LOAD_Y]] fastmath<fast> : f32
17+
!CHECK: %cst_0 = arith.constant 2.000000e+00 : f32
18+
!CHECK: %[[Y_1_2:[0-9]+]] = arith.addf %[[Y_1]], %cst_0 fastmath<fast> : f32
19+
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<f32> {
20+
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: f32):
21+
!CHECK: %[[ARG_P:[0-9]+]] = arith.addf %[[ARG]], %[[Y_1_2]] fastmath<fast> : f32
22+
!CHECK: omp.yield(%[[ARG_P]] : f32)
23+
!CHECK: }
24+
25+
26+
subroutine f01(x, y, z)
27+
implicit none
28+
complex :: x, y, z
29+
30+
!$omp atomic update
31+
x = (x + y) + z
32+
end
33+
34+
!CHECK-LABEL: func.func @_QPf01
35+
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
36+
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
37+
!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
38+
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<complex<f32>>
39+
!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<complex<f32>>
40+
!CHECK: %[[Y_Z:[0-9]+]] = fir.addc %[[LOAD_Y]], %[[LOAD_Z]] {fastmath = #arith.fastmath<fast>} : complex<f32>
41+
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<complex<f32>> {
42+
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: complex<f32>):
43+
!CHECK: %[[ARG_P:[0-9]+]] = fir.addc %[[ARG]], %[[Y_Z]] {fastmath = #arith.fastmath<fast>} : complex<f32>
44+
!CHECK: omp.yield(%[[ARG_P]] : complex<f32>)
45+
!CHECK: }
46+
47+
48+
subroutine f02(x, y)
49+
implicit none
50+
complex :: x
51+
real :: y
52+
53+
!$omp atomic update
54+
x = (real(x) + y) + 1
55+
end
56+
57+
!CHECK-LABEL: func.func @_QPf02
58+
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
59+
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
60+
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<f32>
61+
!CHECK: %cst = arith.constant 1.000000e+00 : f32
62+
!CHECK: %[[Y_1:[0-9]+]] = arith.addf %[[LOAD_Y]], %cst fastmath<fast> : f32
63+
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<complex<f32>> {
64+
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: complex<f32>):
65+
!CHECK: %[[ARG_X:[0-9]+]] = fir.extract_value %[[ARG]], [0 : index] : (complex<f32>) -> f32
66+
!CHECK: %[[ARG_P:[0-9]+]] = arith.addf %[[ARG_X]], %[[Y_1]] fastmath<fast> : f32
67+
!CHECK: %cst_0 = arith.constant 0.000000e+00 : f32
68+
!CHECK: %[[CPLX:[0-9]+]] = fir.undefined complex<f32>
69+
!CHECK: %[[CPLX_I:[0-9]+]] = fir.insert_value %[[CPLX]], %[[ARG_P]], [0 : index] : (complex<f32>, f32) -> complex<f32>
70+
!CHECK: %[[CPLX_R:[0-9]+]] = fir.insert_value %[[CPLX_I]], %cst_0, [1 : index] : (complex<f32>, f32) -> complex<f32>
71+
!CHECK: omp.yield(%[[CPLX_R]] : complex<f32>)
72+
!CHECK: }
73+
74+
75+
subroutine f03(x, a, b, c)
76+
implicit none
77+
real(kind=4) :: x
78+
real(kind=8) :: a, b, c
79+
80+
!$omp atomic update
81+
x = ((b + a) + x) + c
82+
end
83+
84+
!CHECK-LABEL: func.func @_QPf03
85+
!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
86+
!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
87+
!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
88+
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
89+
!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<f64>
90+
!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<f64>
91+
!CHECK: %[[A_B:[0-9]+]] = arith.addf %[[LOAD_B]], %[[LOAD_A]] fastmath<fast> : f64
92+
!CHECK: %[[LOAD_C:[0-9]+]] = fir.load %[[C]]#0 : !fir.ref<f64>
93+
!CHECK: %[[A_B_C:[0-9]+]] = arith.addf %[[A_B]], %[[LOAD_C]] fastmath<fast> : f64
94+
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<f32> {
95+
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: f32):
96+
!CHECK: %[[ARG_8:[0-9]+]] = fir.convert %[[ARG]] : (f32) -> f64
97+
!CHECK: %[[ARG_P:[0-9]+]] = arith.addf %[[ARG_8]], %[[A_B_C]] fastmath<fast> : f64
98+
!CHECK: %[[ARG_4:[0-9]+]] = fir.convert %[[ARG_P]] : (f64) -> f32
99+
!CHECK: omp.yield(%[[ARG_4]] : f32)
100+
!CHECK: }

0 commit comments

Comments
 (0)