Skip to content

Commit 8c1dad7

Browse files
committed
compiler: fix complex reductions for gnu
1 parent a49020d commit 8c1dad7

File tree

4 files changed

+109
-4
lines changed

4 files changed

+109
-4
lines changed

devito/mpi/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
def cleanup():
4141
devito_mpi_finalize()
4242
atexit.register(cleanup)
43-
except ImportError as e:
43+
except (RuntimeError, ImportError) as e:
4444
# Dummy fallback in case mpi4py/MPI aren't available
4545
class NoneMetaclass(type):
4646
def __getattr__(self, name):

devito/passes/iet/languages/openmp.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from packaging.version import Version
33

44
import cgen as c
5+
import numpy as np
56
from sympy import And, Ne, Not
67

78
from devito.arch import AMDGPUX, NVIDIAX, INTELGPUX, PVC
89
from devito.arch.compiler import GNUCompiler, NvidiaCompiler
10+
from devito.exceptions import InvalidOperator
911
from devito.ir import (Call, Conditional, DeviceCall, List, Pragma, Prodder,
1012
ParallelBlock, PointerCast, While, FindSymbols)
1113
from devito.passes.iet.definitions import DataManager, DeviceAwareDataManager
@@ -18,6 +20,7 @@
1820
from devito.passes.iet.languages.C import CBB
1921
from devito.passes.iet.languages.CXX import CXXBB
2022
from devito.symbolics import CondEq, DefFunction
23+
from devito.symbolics.extended_sympy import UnaryOp
2124
from devito.tools import filter_ordered
2225

2326
__all__ = ['SimdOmpizer', 'Ompizer', 'OmpIteration', 'OmpRegion',
@@ -113,6 +116,44 @@ def _generate(self):
113116
return self.pragma % (joins(*items), n)
114117

115118

119+
class RealExt(UnaryOp):

    """
    Unary operator rendered as the ``__real__ `` prefix in front of its
    operand, i.e. the compiler-extension spelling used to address the real
    part of a complex value (e.g. ``__real__ u[i] += ...``).
    """

    # Trailing space is intentional: the prefix is a keyword, not a symbol
    _op = '__real__ '
122+
123+
124+
class ImagExt(UnaryOp):

    """
    Unary operator rendered as the ``__imag__ `` prefix in front of its
    operand, i.e. the compiler-extension spelling used to address the
    imaginary part of a complex value (e.g. ``__imag__ u[i] += ...``).
    """

    # Trailing space is intentional: the prefix is a keyword, not a symbol
    _op = '__imag__ '
127+
128+
129+
def atomic_add(i, pragmas):
    """
    Decorate the increment `i` with `pragmas`, splitting it into separate
    real/imaginary updates when the incremented object is complex.

    Parameters
    ----------
    i
        The increment node. It must expose ``i.expr.lhs`` and ``i.expr.rhs``
        (each with a ``dtype``) and support ``_rebuild``.
    pragmas
        Forwarded unchanged as the ``pragmas`` of every rebuilt node.

    Returns
    -------
    The rebuilt node. For a complex lhs and a complex rhs this is a ``List``
    holding two increments (real part, then imaginary part); for a complex
    lhs and a non-complex rhs it is a single increment on the real part; for
    a non-complex lhs and a non-complex rhs it is `i` itself rebuilt with
    `pragmas`, since there is nothing to split.

    Raises
    ------
    InvalidOperator
        If the lhs is not complex while the rhs is complex.
    """
    lhs, rhs = i.expr.lhs, i.expr.rhs
    lhs_complex = np.issubdtype(lhs.dtype, np.complexfloating)
    rhs_complex = np.issubdtype(rhs.dtype, np.complexfloating)

    if not lhs_complex:
        if rhs_complex:
            # Real i, complex j
            raise InvalidOperator("Atomic add not implemented for real "
                                  "Functions with complex increments")
        # Real i, real j
        # May be reached when only some of the increments being processed
        # are complex; a plain atomic update is all that is needed
        return i._rebuild(pragmas=pragmas)

    def _split(new_lhs, new_rhs):
        # Same increment, acting on one component only
        return i._rebuild(expr=i.expr._rebuild(lhs=new_lhs, rhs=new_rhs),
                          pragmas=pragmas)

    if rhs_complex:
        # Complex i, complex j
        # Atomic add real and imaginary parts separately
        return List(body=[_split(RealExt(lhs), RealExt(rhs)),
                          _split(ImagExt(lhs), ImagExt(rhs))])

    # Complex i, real j
    # Atomic add j to real part of i
    return _split(RealExt(lhs), rhs)
155+
156+
116157
class AbstractOmpBB(LangBB):
117158

118159
mapper = {
@@ -134,7 +175,8 @@ class AbstractOmpBB(LangBB):
134175
'simd-for-aligned': lambda n, *a:
135176
SimdForAligned('omp simd aligned(%s:%d)', arguments=(n, *a)),
136177
'atomic':
137-
Pragma('omp atomic update')
178+
Pragma('omp atomic update'),
179+
'split-atomic': lambda i: atomic_add(i, Pragma('omp atomic update'))
138180
}
139181

140182
Region = OmpRegion
@@ -241,6 +283,13 @@ def _support_array_reduction(cls, compiler):
241283
else:
242284
return True
243285

286+
@classmethod
def _support_complex_reduction(cls, compiler):
    """
    Return True if complex-valued reductions can be used with `compiler`,
    False otherwise.
    """
    # GCC doesn't support complex reductions
    return not isinstance(compiler, GNUCompiler)
292+
244293

245294
class Ompizer(AbstractOmpizer):
246295
langbb = OmpBB

devito/passes/iet/parpragma.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ class PragmaSimdTransformer(PragmaTransformer):
4242
def _support_array_reduction(cls, compiler):
4343
return True
4444

45+
@classmethod
46+
def _support_complex_reduction(cls, compiler):
47+
return False
48+
4549
@property
4650
def simd_reg_nbytes(self):
4751
return self.platform.simd_reg_nbytes
@@ -238,8 +242,13 @@ def _make_reductions(self, partree):
238242
# Implement reduction
239243
mapper = {partree.root: partree.root._rebuild(reduction=reductions)}
240244
elif all(i is OpInc for _, _, i in reductions):
241-
# Use atomic increments
242-
mapper = {i: i._rebuild(pragmas=self.langbb['atomic']) for i in exprs}
245+
test2 = not self._support_complex_reduction(self.compiler) and \
246+
any(np.iscomplexobj(i.dtype(0)) for i, _, _ in reductions)
247+
if test2:
248+
mapper = {i: self.langbb['split-atomic'](i) for i in exprs}
249+
else:
250+
# Use atomic increments
251+
mapper = {i: i._rebuild(pragmas=self.langbb['atomic']) for i in exprs}
243252
else:
244253
raise NotImplementedError
245254

tests/test_dtypes.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
import pytest
33
import sympy
44

5+
from conftest import skipif
6+
57
from devito import (
68
Constant, Eq, Function, Grid, Operator, exp, log, sin, configuration
79
)
10+
from devito.arch.compiler import GNUCompiler
11+
from devito.exceptions import InvalidOperator
812
from devito.ir.cgen.printer import BasePrinter
913
from devito.passes.iet.langbase import LangBB
1014
from devito.passes.iet.languages.C import CBB, CPrinter
@@ -13,6 +17,7 @@
1317
from devito.symbolics.extended_dtypes import ctypes_vector_mapper
1418
from devito.types.basic import Basic, Scalar, Symbol
1519
from devito.types.dense import TimeFunction
20+
from devito.types.sparse import SparseTimeFunction
1621

1722
# Mappers for language-specific types and headers
1823
_languages: dict[str, type[LangBB]] = {
@@ -274,3 +279,45 @@ def test_complex_space_deriv(dtype: np.dtype[np.complexfloating]) -> None:
274279
dfdy = h.data.T[1:-1, 1:-1]
275280
assert np.allclose(dfdx, np.ones((5, 5), dtype=dtype))
276281
assert np.allclose(dfdy, np.ones((5, 5), dtype=dtype))
282+
283+
284+
@skipif('device')
@pytest.mark.parametrize('dtypeu', [np.float32, np.complex64, np.complex128])
def test_complex_reduction(dtypeu: np.dtype[np.complexfloating]) -> None:
    """
    Inject a sparse function `s` into a dense function `u` and check both
    the generated increment and the resulting value, for the combinations
    of `u` and `s` dtypes derived from `dtypeu`.
    """
    grid = Grid((11, 11))
    u = TimeFunction(name="u", grid=grid, space_order=2, time_order=1, dtype=dtypeu)

    u_is_real = np.issubdtype(dtypeu, np.floating)
    u_is_complex = np.issubdtype(dtypeu, np.complexfloating)

    def build(src):
        # Copy `u` forward in time, then inject `src` into it
        return Operator([Eq(u.forward, u)] + src.inject(u.forward, expr=src))

    # The sparse dtype is first `dtypeu` itself, then its real counterpart.
    # NOTE(review): for a real `dtypeu` both entries are the same real type,
    # so with the current parametrization the "u real / s complex" branch
    # below looks unreachable -- confirm whether that case is meant to be
    # covered here.
    for dtypes in (dtypeu, type(dtypeu(0).real)):
        s_is_complex = np.issubdtype(dtypes, np.complexfloating)

        u.data.fill(0)
        s = SparseTimeFunction(name="s", grid=grid, npoint=1, nt=10, dtype=dtypes)
        if s_is_complex:
            fill, expected = 1 + 2j, 8. + 16.j
        else:
            fill, expected = 1, 8.
        s.data[:] = fill
        s.coordinates.data[:] = [.5, .5]

        # s complex and u real should error
        if u_is_real and s_is_complex:
            with pytest.raises(InvalidOperator):
                build(s)
            continue

        op = build(s)
        op()
        code = str(op)

        if isinstance(configuration['compiler'], GNUCompiler) and u_is_complex:
            # The complex increment is expected to be split into two updates
            real_lhs = '__real__ u[t1][rsx + posx + 2][rsy + posy + 2]'
            imag_lhs = '__imag__ u[t1][rsx + posx + 2][rsy + posy + 2]'
            assert f'{real_lhs} += __real__ r0' in code
            assert f'{imag_lhs} += __imag__ r0' in code
        else:
            assert 'u[t1][rsx + posx + 2][rsy + posy + 2] += r0' in code

        assert np.isclose(u.data[0, 5, 5], expected)

0 commit comments

Comments
 (0)