Skip to content

Commit ea1ed48

Browse files
Generic MPI FFT class (#408)
* Added generic MPIFFT problem class * Fixes * Generalized to `xp` in preparation for GPUs * Fixes * Ported Allen-Cahn to generic MPI FFT implementation
1 parent 3036351 commit ea1ed48

File tree

5 files changed

+246
-320
lines changed

5 files changed

+246
-320
lines changed

pySDC/implementations/problem_classes/AllenCahn_MPIFFT.py

Lines changed: 19 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
import numpy as np
22
from mpi4py import MPI
3-
from mpi4py_fft import PFFT
4-
5-
from pySDC.core.Errors import ProblemError
6-
from pySDC.core.Problem import ptype
7-
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
83

4+
from pySDC.implementations.problem_classes.generic_MPIFFT_Laplacian import IMEX_Laplacian_MPIFFT
95
from mpi4py_fft import newDistArray
106

117

12-
class allencahn_imex(ptype):
8+
class allencahn_imex(IMEX_Laplacian_MPIFFT):
139
r"""
1410
Example implementing the :math:`N`-dimensional Allen-Cahn equation with periodic boundary conditions :math:`u \in [0, 1]^2`
1511
@@ -64,68 +60,21 @@ class allencahn_imex(ptype):
6460
.. [1] https://mpi4py-fft.readthedocs.io/en/latest/
6561
"""
6662

67-
dtype_u = mesh
68-
dtype_f = imex_mesh
69-
7063
def __init__(
7164
self,
72-
nvars=None,
7365
eps=0.04,
7466
radius=0.25,
75-
spectral=None,
7667
dw=0.0,
77-
L=1.0,
7868
init_type='circle',
79-
comm=MPI.COMM_WORLD,
69+
**kwargs,
8070
):
81-
"""Initialization routine"""
82-
83-
if nvars is None:
84-
nvars = (128, 128)
85-
86-
if not (isinstance(nvars, tuple) and len(nvars) > 1):
87-
raise ProblemError('Need at least two dimensions')
88-
89-
# Creating FFT structure
90-
ndim = len(nvars)
91-
axes = tuple(range(ndim))
92-
self.fft = PFFT(comm, list(nvars), axes=axes, dtype=np.float64, collapse=True)
93-
94-
# get test data to figure out type and dimensions
95-
tmp_u = newDistArray(self.fft, spectral)
96-
97-
# invoke super init, passing the communicator and the local dimensions as init
98-
super().__init__(init=(tmp_u.shape, comm, tmp_u.dtype))
99-
self._makeAttributeAndRegister(
100-
'nvars', 'eps', 'radius', 'spectral', 'dw', 'L', 'init_type', 'comm', localVars=locals(), readOnly=True
101-
)
102-
103-
L = np.array([self.L] * ndim, dtype=float)
104-
105-
# get local mesh
106-
X = np.ogrid[self.fft.local_slice(False)]
107-
N = self.fft.global_shape()
108-
for i in range(len(N)):
109-
X[i] = X[i] * L[i] / N[i]
110-
self.X = [np.broadcast_to(x, self.fft.shape(False)) for x in X]
111-
112-
# get local wavenumbers and Laplace operator
113-
s = self.fft.local_slice()
114-
N = self.fft.global_shape()
115-
k = [np.fft.fftfreq(n, 1.0 / n).astype(int) for n in N[:-1]]
116-
k.append(np.fft.rfftfreq(N[-1], 1.0 / N[-1]).astype(int))
117-
K = [ki[si] for ki, si in zip(k, s)]
118-
Ks = np.meshgrid(*K, indexing='ij', sparse=True)
119-
Lp = 2 * np.pi / L
120-
for i in range(ndim):
121-
Ks[i] = (Ks[i] * Lp[i]).astype(float)
122-
K = [np.broadcast_to(k, self.fft.shape(True)) for k in Ks]
123-
K = np.array(K).astype(float)
124-
self.K2 = np.sum(K * K, 0, dtype=float)
125-
126-
# Need this for diagnostics
127-
self.dx = self.L / nvars[0]
128-
self.dy = self.L / nvars[1]
71+
kwargs['L'] = kwargs.get('L', 1.0)
72+
super().__init__(alpha=1.0, dtype=np.dtype('float'), **kwargs)
73+
self._makeAttributeAndRegister('eps', 'radius', 'dw', 'init_type', localVars=locals(), readOnly=True)
74+
75+
def _eval_explicit_part(self, u, t, f_expl):
76+
f_expl[:] = -2.0 / self.eps**2 * u * (1.0 - u) * (1.0 - 2.0 * u) - 6.0 * self.dw * u * (1.0 - u)
77+
return f_expl
12978

13079
def eval_f(self, u, t):
13180
"""
@@ -146,56 +95,24 @@ def eval_f(self, u, t):
14695

14796
f = self.dtype_f(self.init)
14897

98+
f.impl[:] = self._eval_Laplacian(u, f.impl)
99+
149100
if self.spectral:
150101
f.impl = -self.K2 * u
151102

152103
if self.eps > 0:
153104
tmp = self.fft.backward(u)
154-
tmpf = -2.0 / self.eps**2 * tmp * (1.0 - tmp) * (1.0 - 2.0 * tmp) - 6.0 * self.dw * tmp * (1.0 - tmp)
155-
f.expl[:] = self.fft.forward(tmpf)
105+
tmp[:] = self._eval_explicit_part(tmp, t, tmp)
106+
f.expl[:] = self.fft.forward(tmp)
156107

157108
else:
158-
u_hat = self.fft.forward(u)
159-
lap_u_hat = -self.K2 * u_hat
160-
f.impl[:] = self.fft.backward(lap_u_hat, f.impl)
161109

162110
if self.eps > 0:
163-
f.expl = -2.0 / self.eps**2 * u * (1.0 - u) * (1.0 - 2.0 * u) - 6.0 * self.dw * u * (1.0 - u)
111+
f.expl[:] = self._eval_explicit_part(u, t, f.expl)
164112

113+
self.work_counters['rhs']()
165114
return f
166115

167-
def solve_system(self, rhs, factor, u0, t):
168-
"""
169-
Simple FFT solver for the diffusion part.
170-
171-
Parameters
172-
----------
173-
rhs : dtype_f
174-
Right-hand side for the linear system.
175-
factor : float
176-
Abbrev. for the node-to-node stepsize (or any other factor required).
177-
u0 : dtype_u
178-
Initial guess for the iterative solver (not used here so far).
179-
t : float
180-
Current time (e.g. for time-dependent BCs).
181-
182-
Returns
183-
-------
184-
me : dtype_u
185-
The solution as mesh.
186-
"""
187-
188-
if self.spectral:
189-
me = rhs / (1.0 + factor * self.K2)
190-
191-
else:
192-
me = self.dtype_u(self.init)
193-
rhs_hat = self.fft.forward(rhs)
194-
rhs_hat /= 1.0 + factor * self.K2
195-
me[:] = self.fft.backward(rhs_hat)
196-
197-
return me
198-
199116
def u_exact(self, t):
200117
r"""
201118
Routine to compute the exact solution at time :math:`t`.
@@ -289,8 +206,9 @@ def eval_f(self, u, t):
289206

290207
f = self.dtype_f(self.init)
291208

209+
f.impl[:] = self._eval_Laplacian(u, f.impl)
210+
292211
if self.spectral:
293-
f.impl = -self.K2 * u
294212

295213
tmp = newDistArray(self.fft, False)
296214
tmp[:] = self.fft.backward(u, tmp)
@@ -324,9 +242,6 @@ def eval_f(self, u, t):
324242
f.expl[:] = self.fft.forward(tmpf)
325243

326244
else:
327-
u_hat = self.fft.forward(u)
328-
lap_u_hat = -self.K2 * u_hat
329-
f.impl[:] = self.fft.backward(lap_u_hat, f.impl)
330245

331246
if self.eps > 0:
332247
f.expl = -2.0 / self.eps**2 * u * (1.0 - u) * (1.0 - 2.0 * u)
@@ -353,4 +268,5 @@ def eval_f(self, u, t):
353268

354269
f.expl -= 6.0 * dw * u * (1.0 - u)
355270

271+
self.work_counters['rhs']()
356272
return f

pySDC/implementations/problem_classes/Brusselator.py

Lines changed: 24 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
import numpy as np
22
from mpi4py import MPI
3-
from mpi4py_fft import PFFT
43

5-
from pySDC.core.Errors import ProblemError
6-
from pySDC.core.Problem import ptype, WorkCounter
7-
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
4+
from pySDC.implementations.problem_classes.generic_MPIFFT_Laplacian import IMEX_Laplacian_MPIFFT
85

9-
from mpi4py_fft import newDistArray
106

11-
12-
class Brusselator(ptype):
7+
class Brusselator(IMEX_Laplacian_MPIFFT):
138
r"""
149
Two-dimensional Brusselator from [1]_.
1510
This is a reaction-diffusion equation with non-autonomous source term:
@@ -27,68 +22,29 @@ class Brusselator(ptype):
2722
.. [1] https://link.springer.com/book/10.1007/978-3-642-05221-7
2823
"""
2924

30-
dtype_u = mesh
31-
dtype_f = imex_mesh
32-
33-
def __init__(self, nvars=None, alpha=0.1, comm=MPI.COMM_WORLD):
25+
def __init__(self, alpha=0.1, **kwargs):
3426
"""Initialization routine"""
35-
nvars = (128,) * 2 if nvars is None else nvars
36-
L = 1.0
37-
38-
if not (isinstance(nvars, tuple) and len(nvars) > 1):
39-
raise ProblemError('Need at least two dimensions')
40-
41-
# Create FFT structure
42-
self.ndim = len(nvars)
43-
axes = tuple(range(self.ndim))
44-
self.fft = PFFT(
45-
comm,
46-
list(nvars),
47-
axes=axes,
48-
dtype=np.float64,
49-
collapse=True,
50-
backend='fftw',
51-
)
52-
53-
# get test data to figure out type and dimensions
54-
tmp_u = newDistArray(self.fft, False)
27+
super().__init__(spectral=False, L=1.0, dtype='d', alpha=alpha, **kwargs)
5528

5629
# prepare the array with two components
57-
shape = (2,) + tmp_u.shape
30+
shape = (2,) + (self.init[0])
5831
self.iU = 0
5932
self.iV = 1
33+
self.init = (shape, self.comm, np.dtype('float'))
34+
35+
def _eval_explicit_part(self, u, t, f_expl):
36+
iU, iV = self.iU, self.iV
37+
x, y = self.X[0], self.X[1]
6038

61-
super().__init__(init=(shape, comm, tmp_u.dtype))
62-
self._makeAttributeAndRegister('nvars', 'alpha', 'L', 'comm', localVars=locals(), readOnly=True)
63-
64-
L = np.array([self.L] * self.ndim, dtype=float)
65-
66-
# get local mesh for distributed FFT
67-
X = np.ogrid[self.fft.local_slice(False)]
68-
N = self.fft.global_shape()
69-
for i in range(len(N)):
70-
X[i] = X[i] * L[i] / N[i]
71-
self.X = [np.broadcast_to(x, self.fft.shape(False)) for x in X]
72-
73-
# get local wavenumbers and Laplace operator
74-
s = self.fft.local_slice()
75-
N = self.fft.global_shape()
76-
k = [np.fft.fftfreq(n, 1.0 / n).astype(int) for n in N[:-1]]
77-
k.append(np.fft.rfftfreq(N[-1], 1.0 / N[-1]).astype(int))
78-
K = [ki[si] for ki, si in zip(k, s)]
79-
Ks = np.meshgrid(*K, indexing='ij', sparse=True)
80-
Lp = 2 * np.pi / L
81-
for i in range(self.ndim):
82-
Ks[i] = (Ks[i] * Lp[i]).astype(float)
83-
K = [np.broadcast_to(k, self.fft.shape(True)) for k in Ks]
84-
K = np.array(K).astype(float)
85-
self.K2 = np.sum(K * K, 0, dtype=float)
86-
87-
# Need this for diagnostics
88-
self.dx = self.L / nvars[0]
89-
self.dy = self.L / nvars[1]
90-
91-
self.work_counters['rhs'] = WorkCounter()
39+
# evaluate time independent part
40+
f_expl[iU, ...] = 1.0 + u[iU] ** 2 * u[iV] - 4.4 * u[iU]
41+
f_expl[iV, ...] = 3.4 * u[iU] - u[iU] ** 2 * u[iV]
42+
43+
# add time-dependent part
44+
if t >= 1.1:
45+
mask = (x - 0.3) ** 2 + (y - 0.6) ** 2 <= 0.1**2
46+
f_expl[iU][mask] += 5.0
47+
return f_expl
9248

9349
def eval_f(self, u, t):
9450
"""
@@ -106,25 +62,13 @@ def eval_f(self, u, t):
10662
f : dtype_f
10763
The right-hand side of the problem.
10864
"""
109-
iU, iV = self.iU, self.iV
110-
x, y = self.X[0], self.X[1]
111-
11265
f = self.dtype_f(self.init)
11366

11467
# evaluate Laplacian to be solved implicitly
11568
for i in [self.iU, self.iV]:
116-
u_hat = self.fft.forward(u[i, ...])
117-
lap_u_hat = -self.alpha * self.K2 * u_hat
118-
f.impl[i, ...] = self.fft.backward(lap_u_hat, f.impl[i, ...])
69+
f.impl[i, ...] = self._eval_Laplacian(u[i], f.impl[i])
11970

120-
# evaluate time independent part
121-
f.expl[iU, ...] = 1.0 + u[iU] ** 2 * u[iV] - 4.4 * u[iU]
122-
f.expl[iV, ...] = 3.4 * u[iU] - u[iU] ** 2 * u[iV]
123-
124-
# add time-dependent part
125-
if t >= 1.1:
126-
mask = (x - 0.3) ** 2 + (y - 0.6) ** 2 <= 0.1**2
127-
f.expl[iU][mask] += 5.0
71+
f.expl[:] = self._eval_explicit_part(u, t, f.expl)
12872

12973
self.work_counters['rhs']()
13074

@@ -153,9 +97,7 @@ def solve_system(self, rhs, factor, u0, t):
15397
me = self.dtype_u(self.init)
15498

15599
for i in [self.iU, self.iV]:
156-
rhs_hat = self.fft.forward(rhs[i, ...])
157-
rhs_hat /= 1.0 + factor * self.K2 * self.alpha
158-
me[i, ...] = self.fft.backward(rhs_hat, me[i, ...])
100+
me[i, ...] = self._invert_Laplacian(me[i], factor, rhs[i])
159101

160102
return me
161103

@@ -184,8 +126,8 @@ def u_exact(self, t, u_init=None, t_init=None):
184126
me = self.dtype_u(self.init, val=0.0)
185127

186128
if t == 0:
187-
me[iU, ...] = 22.0 * y * (1 - y / self.L) ** (3.0 / 2.0) / self.L
188-
me[iV, ...] = 27.0 * x * (1 - x / self.L) ** (3.0 / 2.0) / self.L
129+
me[iU, ...] = 22.0 * y * (1 - y / self.L[0]) ** (3.0 / 2.0) / self.L[0]
130+
me[iV, ...] = 27.0 * x * (1 - x / self.L[0]) ** (3.0 / 2.0) / self.L[0]
189131
else:
190132

191133
def eval_rhs(t, u):

0 commit comments

Comments
 (0)