Skip to content

Commit 1196767

Browse files
committed
Expanded tests
1 parent 7862e7c commit 1196767

File tree

2 files changed

+45
-27
lines changed

2 files changed

+45
-27
lines changed

pySDC/implementations/problem_classes/generic_MPIFFT_Laplacian.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import numpy as np
22
from mpi4py import MPI
3-
from mpi4py_fft import PFFT
3+
from mpi4py_fft import PFFT, newDistArray
44

55
from pySDC.core.errors import ProblemError
66
from pySDC.core.problem import Problem, WorkCounter
77
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
88

9-
from mpi4py_fft import newDistArray
10-
119

1210
class IMEX_Laplacian_MPIFFT(Problem):
1311
r"""
@@ -99,14 +97,24 @@ def __init__(
9997
'nvars', 'spectral', 'L', 'alpha', 'comm', 'x0', 'useGPU', localVars=locals(), readOnly=True
10098
)
10199

102-
# get local mesh
100+
self.getLocalGrid()
101+
self.getLaplacian()
102+
103+
# Need this for diagnostics
104+
self.dx = self.L[0] / nvars[0]
105+
self.dy = self.L[1] / nvars[1]
106+
107+
# work counters
108+
self.work_counters['rhs'] = WorkCounter()
109+
110+
def getLocalGrid(self):
103111
X = list(self.xp.ogrid[self.fft.local_slice(False)])
104112
N = self.fft.global_shape()
105113
for i in range(len(N)):
106-
X[i] = x0 + (X[i] * L[i] / N[i])
114+
X[i] = self.x0 + (X[i] * self.L[i] / N[i])
107115
self.X = [self.xp.broadcast_to(x, self.fft.shape(False)) for x in X]
108116

109-
# get local wavenumbers and Laplace operator
117+
def getLaplacian(self):
110118
s = self.fft.local_slice()
111119
N = self.fft.global_shape()
112120
k = [self.xp.fft.fftfreq(n, 1.0 / n).astype(int) for n in N]
@@ -117,14 +125,7 @@ def __init__(
117125
Ks[i] = (Ks[i] * Lp[i]).astype(float)
118126
K = [self.xp.broadcast_to(k, self.fft.shape(True)) for k in Ks]
119127
K = self.xp.array(K).astype(float)
120-
self.K2 = self.xp.sum(K * K, 0, dtype=float) # Laplacian in spectral space
121-
122-
# Need this for diagnostics
123-
self.dx = self.L[0] / nvars[0]
124-
self.dy = self.L[1] / nvars[1]
125-
126-
# work counters
127-
self.work_counters['rhs'] = WorkCounter()
128+
self.K2 = self.xp.sum(K * K, 0, dtype=float)
128129

129130
def eval_f(self, u, t):
130131
"""

pySDC/tests/test_problems/test_generic_MPIFFT.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,56 @@
22

33

44
@pytest.mark.mpi4py
5-
@pytest.mark.parametrize('nx', [16, 32])
6-
@pytest.mark.parametrize('ny', [16, 32])
7-
@pytest.mark.parametrize('nz', [0])
5+
@pytest.mark.parametrize('nx', [8, 16])
6+
@pytest.mark.parametrize('ny', [8, 16])
7+
@pytest.mark.parametrize('nz', [0, 8])
88
@pytest.mark.parametrize('f', [1, 3])
9+
@pytest.mark.parametrize('spectral', [True, False])
910
@pytest.mark.parametrize('direction', [0, 1, 10])
10-
def test_derivative(nx, ny, nz, f, direction):
11+
def test_derivative(nx, ny, nz, f, spectral, direction):
1112
from pySDC.implementations.problem_classes.generic_MPIFFT_Laplacian import IMEX_Laplacian_MPIFFT
1213

1314
nvars = (nx, ny)
1415
if nz > 0:
15-
nvars.append(nz)
16-
prob = IMEX_Laplacian_MPIFFT(nvars=nvars)
16+
nvars += (nz,)
17+
prob = IMEX_Laplacian_MPIFFT(nvars=nvars, spectral=spectral)
1718

1819
xp = prob.xp
1920

2021
if direction == 0:
21-
u = xp.sin(f * prob.X[0])
22+
_u = xp.sin(f * prob.X[0])
2223
du_expect = -(f**2) * xp.sin(f * prob.X[0])
2324
elif direction == 1:
24-
u = xp.sin(f * prob.X[1])
25+
_u = xp.sin(f * prob.X[1])
2526
du_expect = -(f**2) * xp.sin(f * prob.X[1])
2627
elif direction == 10:
27-
u = xp.sin(f * prob.X[1]) + xp.cos(f * prob.X[0])
28+
_u = xp.sin(f * prob.X[1]) + xp.cos(f * prob.X[0])
2829
du_expect = -(f**2) * xp.sin(f * prob.X[1]) - f**2 * xp.cos(f * prob.X[0])
2930
else:
3031
raise
3132

32-
du = prob.eval_f(u, 0).impl
33+
if spectral:
34+
u = prob.fft.forward(_u)
35+
else:
36+
u = _u
37+
38+
_du = prob.eval_f(u, 0).impl
39+
40+
if spectral:
41+
du = prob.fft.backward(_du)
42+
else:
43+
du = _du
3344
assert xp.allclose(du, du_expect), 'Got unexpected derivative'
3445

35-
_u = prob.solve_system(du, factor=1e8, u0=du, t=0) * -1e8
36-
assert xp.allclose(_u, u, atol=1e-7), 'Got unexpected inverse derivative'
46+
u2 = prob.solve_system(_du, factor=1e8, u0=du, t=0) * -1e8
47+
48+
if spectral:
49+
_u2 = prob.fft.backward(u2)
50+
else:
51+
_u2 = u2
52+
53+
assert xp.allclose(_u2, _u, atol=1e-7), 'Got unexpected inverse derivative'
3754

3855

3956
if __name__ == '__main__':
40-
test_derivative(32, 32, 0, 1, 1)
57+
test_derivative(6, 6, 6, 3, False, 1)

0 commit comments

Comments
 (0)