|
2 | 2 |
|
3 | 3 |
|
4 | 4 | @pytest.mark.mpi4py |
5 | | -@pytest.mark.parametrize('nx', [16, 32]) |
6 | | -@pytest.mark.parametrize('ny', [16, 32]) |
7 | | -@pytest.mark.parametrize('nz', [0]) |
| 5 | +@pytest.mark.parametrize('nx', [8, 16]) |
| 6 | +@pytest.mark.parametrize('ny', [8, 16]) |
| 7 | +@pytest.mark.parametrize('nz', [0, 8]) |
8 | 8 | @pytest.mark.parametrize('f', [1, 3]) |
| 9 | +@pytest.mark.parametrize('spectral', [True, False]) |
9 | 10 | @pytest.mark.parametrize('direction', [0, 1, 10]) |
10 | | -def test_derivative(nx, ny, nz, f, direction): |
| 11 | +def test_derivative(nx, ny, nz, f, spectral, direction): |
11 | 12 | from pySDC.implementations.problem_classes.generic_MPIFFT_Laplacian import IMEX_Laplacian_MPIFFT |
12 | 13 |
|
13 | 14 | nvars = (nx, ny) |
14 | 15 | if nz > 0: |
15 | | - nvars.append(nz) |
16 | | - prob = IMEX_Laplacian_MPIFFT(nvars=nvars) |
| 16 | + nvars += (nz,) |
| 17 | + prob = IMEX_Laplacian_MPIFFT(nvars=nvars, spectral=spectral) |
17 | 18 |
|
18 | 19 | xp = prob.xp |
19 | 20 |
|
20 | 21 | if direction == 0: |
21 | | - u = xp.sin(f * prob.X[0]) |
| 22 | + _u = xp.sin(f * prob.X[0]) |
22 | 23 | du_expect = -(f**2) * xp.sin(f * prob.X[0]) |
23 | 24 | elif direction == 1: |
24 | | - u = xp.sin(f * prob.X[1]) |
| 25 | + _u = xp.sin(f * prob.X[1]) |
25 | 26 | du_expect = -(f**2) * xp.sin(f * prob.X[1]) |
26 | 27 | elif direction == 10: |
27 | | - u = xp.sin(f * prob.X[1]) + xp.cos(f * prob.X[0]) |
| 28 | + _u = xp.sin(f * prob.X[1]) + xp.cos(f * prob.X[0]) |
28 | 29 | du_expect = -(f**2) * xp.sin(f * prob.X[1]) - f**2 * xp.cos(f * prob.X[0]) |
29 | 30 | else: |
30 | 31 | raise |
31 | 32 |
|
32 | | - du = prob.eval_f(u, 0).impl |
| 33 | + if spectral: |
| 34 | + u = prob.fft.forward(_u) |
| 35 | + else: |
| 36 | + u = _u |
| 37 | + |
| 38 | + _du = prob.eval_f(u, 0).impl |
| 39 | + |
| 40 | + if spectral: |
| 41 | + du = prob.fft.backward(_du) |
| 42 | + else: |
| 43 | + du = _du |
33 | 44 | assert xp.allclose(du, du_expect), 'Got unexpected derivative' |
34 | 45 |
|
35 | | - _u = prob.solve_system(du, factor=1e8, u0=du, t=0) * -1e8 |
36 | | - assert xp.allclose(_u, u, atol=1e-7), 'Got unexpected inverse derivative' |
| 46 | + u2 = prob.solve_system(_du, factor=1e8, u0=du, t=0) * -1e8 |
| 47 | + |
| 48 | + if spectral: |
| 49 | + _u2 = prob.fft.backward(u2) |
| 50 | + else: |
| 51 | + _u2 = u2 |
| 52 | + |
| 53 | + assert xp.allclose(_u2, _u, atol=1e-7), 'Got unexpected inverse derivative' |
37 | 54 |
|
38 | 55 |
|
39 | 56 | if __name__ == '__main__': |
40 | | - test_derivative(32, 32, 0, 1, 1) |
| 57 | + test_derivative(6, 6, 6, 3, False, 1) |
0 commit comments