Skip to content

Commit e4816ff

Browse files
committed
Refactored increment formulation
1 parent eb7cfe3 commit e4816ff

File tree

7 files changed

+80
-50
lines changed

7 files changed

+80
-50
lines changed

pySDC/core/controller.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pySDC.core.base_transfer import BaseTransfer
77
from pySDC.helpers.pysdc_helper import FrozenClass
88
from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
9-
from pySDC.implementations.convergence_controller_classes.store_uold import StoreUOld
109
from pySDC.implementations.hooks.default_hook import DefaultHooks
1110
from pySDC.implementations.hooks.log_timings import CPUTimings
1211

@@ -378,12 +377,10 @@ def __init__(self, controller_params, description, n_steps, useMPI=None):
378377

379378
controller_params['all_to_done'] = True
380379
super().__init__(controller_params=controller_params, description=description, useMPI=useMPI)
381-
self.base_convergence_controllers += [StoreUOld]
382380

383-
self.ParaDiag_block_u0 = None
384381
self.n_steps = n_steps
385382

386-
def FFT_in_time(self):
383+
def FFT_in_time(self, quantity):
387384
"""
388385
Compute weighted forward FFT in time. The weighting is determined by the alpha parameter in ParaDiag
389386
@@ -395,9 +392,9 @@ def FFT_in_time(self):
395392

396393
self.__FFT_matrix = get_weighted_FFT_matrix(self.n_steps, self.params.alpha)
397394

398-
self.apply_matrix(self.__FFT_matrix)
395+
self.apply_matrix(self.__FFT_matrix, quantity)
399396

400-
def iFFT_in_time(self):
397+
def iFFT_in_time(self, quantity):
401398
"""
402399
Compute weighted backward FFT in time. The weighting is determined by the alpha parameter in ParaDiag
403400
"""
@@ -406,4 +403,4 @@ def iFFT_in_time(self):
406403

407404
self.__iFFT_matrix = get_weighted_iFFT_matrix(self.n_steps, self.params.alpha)
408405

409-
self.apply_matrix(self.__iFFT_matrix)
406+
self.apply_matrix(self.__iFFT_matrix, quantity)

pySDC/core/level.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(self, problem_class, problem_params, sweeper_class, sweeper_params,
8484
self.uold = [None] * (self.sweep.coll.num_nodes + 1)
8585
self.u_avg = [None] * self.sweep.coll.num_nodes
8686
self.residual = [None] * self.sweep.coll.num_nodes
87+
self.increment = [None] * self.sweep.coll.num_nodes
8788
self.f = [None] * (self.sweep.coll.num_nodes + 1)
8889
self.fold = [None] * (self.sweep.coll.num_nodes + 1)
8990

pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import itertools
2-
import copy as cp
32
import numpy as np
4-
import dill
53

64
from pySDC.core.controller import ParaDiagController
75
from pySDC.core import step as stepclass
8-
from pySDC.core.errors import ControllerError, CommunicationError
6+
from pySDC.core.errors import ControllerError
97
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting
108
from pySDC.helpers.ParaDiagHelper import get_G_inv_matrix
119

@@ -92,7 +90,7 @@ def ParaDiag(self, local_MS_active):
9290

9391
return all(S.status.done for S in local_MS_active)
9492

95-
def apply_matrix(self, mat):
93+
def apply_matrix(self, mat, quantity):
9694
"""
9795
Apply a matrix on the step level. Needs to be square. Puts the result back into the controller.
9896
@@ -112,29 +110,33 @@ def apply_matrix(self, mat):
112110
None,
113111
] * L
114112

113+
if quantity == 'residual':
114+
me = [S.levels[0].residual for S in self.MS]
115+
elif quantity == 'increment':
116+
me = [S.levels[0].increment for S in self.MS]
117+
else:
118+
raise NotImplementedError
119+
115120
# compute matrix-vector product
116121
for i in range(mat.shape[0]):
117-
res[i] = [prob.u_init for _ in range(M + 1)]
122+
res[i] = [prob.u_init for _ in range(M)]
118123
for j in range(mat.shape[1]):
119-
for m in range(M + 1):
120-
res[i][m] += mat[i, j] * self.MS[j].levels[0].u[m]
124+
for m in range(M):
125+
res[i][m] += mat[i, j] * me[j][m]
121126

122127
# put the result in the "output"
123128
for i in range(mat.shape[0]):
124-
for m in range(M + 1):
125-
self.MS[i].levels[0].u[m] = res[i][m]
129+
for m in range(M):
130+
me[i][m] = res[i][m]
126131

127-
def swap_solution_for_all_at_once_residual(self, local_MS_running):
132+
def compute_all_at_once_residual(self, local_MS_running):
128133
"""
129-
Replace the solution values in the steps with the all-at-once residual.
130-
131134
This requires to communicate the solutions at the end of the steps to be the initial conditions for the next
132135
steps. Afterwards, the residual can be computed locally on the steps.
133136
134137
Args:
135138
local_MS_running (list): list of currently running steps
136139
"""
137-
prob = self.MS[0].levels[0].prob
138140

139141
for S in local_MS_running:
140142
# communicate initial conditions
@@ -143,9 +145,7 @@ def swap_solution_for_all_at_once_residual(self, local_MS_running):
143145
for hook in self.hooks:
144146
hook.pre_comm(step=S, level_number=0)
145147

146-
if S.status.first:
147-
S.levels[0].u[0] = prob.dtype_u(self.ParaDiag_block_u0)
148-
else:
148+
if not S.status.first:
149149
S.levels[0].u[0] = S.prev.levels[0].uend
150150

151151
for hook in self.hooks:
@@ -154,25 +154,16 @@ def swap_solution_for_all_at_once_residual(self, local_MS_running):
154154
# compute residuals locally
155155
S.levels[0].sweep.compute_residual()
156156

157-
# put residual in the solution variables
158-
for m in range(S.levels[0].sweep.coll.num_nodes):
159-
S.levels[0].u[m + 1] = S.levels[0].residual[m]
160-
161-
def swap_increment_for_solution(self, local_MS_running):
157+
def update_solution(self, local_MS_running):
162158
"""
163-
After inversion of the preconditioner, the values stored in the steps are the increment. This function adds the
164-
solution after the previous iteration to arrive at the solution after the current iteration.
165-
Note that we also need to put in the initial conditions back in the first step because they will be perturbed by
166-
the circular preconditioner.
159+
Since we solve for the increment, we need to update the solution between iterations by adding the increment.
167160
168161
Args:
169162
local_MS_running (list): list of currently running steps
170163
"""
171164
for S in local_MS_running:
172-
for m in range(S.levels[0].sweep.coll.num_nodes + 1):
173-
S.levels[0].u[m] = S.levels[0].uold[m] + S.levels[0].u[m]
174-
if S.status.first:
175-
S.levels[0].u[0] = self.ParaDiag_block_u0
165+
for m in range(S.levels[0].sweep.coll.num_nodes):
166+
S.levels[0].u[m + 1] += S.levels[0].increment[m]
176167

177168
def prepare_Jacobians(self, local_MS_running):
178169
# get solutions for constructing average Jacobians
@@ -215,22 +206,22 @@ def it_ParaDiag(self, local_MS_running):
215206
# communicate average residual for setting up Jacobians for non-linear problems
216207
self.prepare_Jacobians(local_MS_running)
217208

218-
# replace the values stored in the steps with the residuals in order to compute the increment
219-
self.swap_solution_for_all_at_once_residual(local_MS_running)
209+
# compute the all-at-once residual to use as right hand side
210+
self.compute_all_at_once_residual(local_MS_running)
220211

221-
# weighted FFT in time
222-
self.FFT_in_time()
212+
# weighted FFT of the residual in time
213+
self.FFT_in_time(quantity='residual')
223214

224215
# perform local solves of "collocation problems" on the steps (can be done in parallel)
225216
for S in local_MS_running:
226217
assert len(S.levels) == 1, 'Multi-level SDC not implemented in ParaDiag'
227218
S.levels[0].sweep.update_nodes()
228219

229-
# inverse FFT in time
230-
self.iFFT_in_time()
220+
# inverse FFT of the increment in time
221+
self.iFFT_in_time(quantity='increment')
231222

232-
# replace the values stored in the steps with the previous solution plus the increment
233-
self.swap_increment_for_solution(local_MS_running)
223+
# get the next iterate by adding increment to previous iterate
224+
self.update_solution(local_MS_running)
234225

235226
for S in local_MS_running:
236227
for hook in self.hooks:
@@ -438,7 +429,6 @@ def restart_block(self, active_slots, time, u0):
438429
u0: initial value to distribute across the steps
439430
440431
"""
441-
self.ParaDiag_block_u0 = u0 # need this for computing residual
442432

443433
for j in range(len(active_slots)):
444434
# get slot number

pySDC/implementations/sweeper_classes/ParaDiagSweepers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def update_nodes(self):
102102

103103
# perform local solves on the collocation nodes, can be parallelized!
104104
if self.params.ignore_ic:
105-
x1 = self.mat_vec(self.S_inv, [self.level.u[m + 1] for m in range(M)])
105+
x1 = self.mat_vec(self.S_inv, [self.level.residual[m] for m in range(M)])
106106
else:
107107
x1 = self.mat_vec(self.S_inv, [self.level.u[0] for _ in range(M)])
108108

@@ -120,7 +120,10 @@ def update_nodes(self):
120120

121121
# update solution and evaluate right hand side
122122
for m in range(M):
123-
L.u[m + 1] = y[m]
123+
if self.params.ignore_ic:
124+
L.increment[m] = y[m]
125+
else:
126+
L.u[m + 1] = y[m]
124127
if self.params.update_f_evals:
125128
L.f[m + 1] = P.eval_f(L.u[m + 1], L.time + L.dt * self.coll.nodes[m])
126129

pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,44 @@ def test_ParaDiag_convergence_rate(L, M, N, alpha):
261261
), f'Convergence rate {convergence_rate:.2e} exceeds upper bound of {convergence_bound:.2e}!'
262262

263263

264+
@pytest.mark.base
265+
@pytest.mark.parametrize('L', [4, 12])
266+
@pytest.mark.parametrize('M', [3, 4])
267+
@pytest.mark.parametrize('N', [1, 2])
268+
def test_fft(L, M, N):
269+
import numpy as np
270+
from pySDC.helpers.ParaDiagHelper import get_FFT_matrix
271+
272+
dt = 1e-2
273+
controller, prob = get_composite_collocation_problem(L, M, N, alpha=1e-1, dt=dt, problem='Dahlquist')
274+
# generate random data
275+
data = np.random.random((L, M, N))
276+
data = np.ones((L, M, N))
277+
278+
for l in range(L):
279+
for m in range(M):
280+
controller.MS[l].levels[0].residual[m] = prob.u_init
281+
controller.MS[l].levels[0].residual[m][:] = data[l, m]
282+
283+
fft_matrix = get_FFT_matrix(L)
284+
controller.apply_matrix(fft_matrix, 'residual')
285+
data_fft = np.fft.fft(data, axis=0, norm='ortho')
286+
287+
for l in range(L):
288+
for m in range(M):
289+
assert np.allclose(controller.MS[l].levels[0].residual[m], data_fft[l, m])
290+
291+
controller.apply_matrix(np.conjugate(fft_matrix), 'residual')
292+
for l in range(L):
293+
for m in range(M):
294+
assert np.allclose(controller.MS[l].levels[0].residual[m], data[l, m])
295+
296+
264297
if __name__ == '__main__':
265-
test_ParaDiag_convergence_rate(4, 3, 1, 1e-4)
298+
test_fft(3, 2, 2)
299+
# test_ParaDiag_convergence_rate(4, 3, 1, 1e-4)
266300
# test_ParaDiag_vs_PFASST(4, 3, 2, 'Dahlquist')
267301
# test_ParaDiag_convergence(4, 3, 1, 1e-4, 'vdp')
268302
# test_IMEX_ParaDiag_convergence(4, 3, 64, 1e-4)
269303
# test_ParaDiag_order(3, 3, 1, 1e-4)
304+
print('done')

pySDC/tests/test_sweepers/test_ParaDiag_sweepers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def test_direct_solve(M, N, ignore_ic):
7575
level.u[m] = prob.u_exact(0)
7676
level.f[m] = prob.eval_f(level.u[m], 0)
7777

78+
level.sweep.compute_residual()
79+
7880
if ignore_ic:
7981
level.u[0][:] = None
8082

@@ -92,6 +94,8 @@ def test_direct_solve(M, N, ignore_ic):
9294
u = sp.linalg.spsolve(C_coll, u0.flatten()).reshape(u0.shape)
9395

9496
for m in range(M):
97+
if ignore_ic:
98+
level.u[m + 1] = level.u[m + 1] + level.increment[m]
9599
assert np.allclose(u[m], level.u[m + 1])
96100

97101
if not ignore_ic:
@@ -100,4 +104,4 @@ def test_direct_solve(M, N, ignore_ic):
100104

101105

102106
if __name__ == '__main__':
103-
test_direct_solve(2, 1, False)
107+
test_direct_solve(2, 1, True)

pySDC/tutorial/step_9/C_paradiag_in_pySDC.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,5 +176,5 @@ def compare_ParaDiag_and_PFASST(n_steps, problem):
176176
params = {
177177
'n_steps': 16,
178178
}
179-
compare_ParaDiag_and_PFASST(**params, problem='advection')
179+
# compare_ParaDiag_and_PFASST(**params, problem='advection')
180180
compare_ParaDiag_and_PFASST(**params, problem='vdp')

0 commit comments

Comments
 (0)