Skip to content

Commit bdf1a2d

Browse files
committed
Fixes and refactor for RKN
1 parent 82864f7 commit bdf1a2d

File tree

3 files changed

+132
-97
lines changed

3 files changed

+132
-97
lines changed

pySDC/implementations/sweeper_classes/Runge_Kutta.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -267,26 +267,28 @@ def compute_end_point(self):
267267
"""
268268
In this Runge-Kutta implementation, the solution to the step is always stored in the last node
269269
"""
270-
if self.level.f[1] is None:
271-
self.level.uend = self.level.u[0]
270+
lvl = self.level
271+
272+
if lvl.f[1] is None:
273+
lvl.uend = lvl.prob.dtype_u(lvl.u[0])
272274
if type(self.coll) == ButcherTableauEmbedded:
273-
self.u_secondary = self.level.u[0].copy()
275+
self.u_secondary = lvl.prob.dtype_u(lvl.u[0])
274276
elif self.coll.globally_stiffly_accurate:
275-
self.level.uend = self.level.u[-1]
277+
lvl.uend = lvl.prob.dtype_u(lvl.u[-1])
276278
if type(self.coll) == ButcherTableauEmbedded:
277-
self.u_secondary = self.level.u[0].copy()
278-
for w2, k in zip(self.coll.weights[1], self.level.f[1:]):
279-
self.u_secondary += self.level.dt * w2 * k
279+
self.u_secondary = lvl.prob.dtype_u(lvl.u[0])
280+
for w2, k in zip(self.coll.weights[1], lvl.f[1:]):
281+
self.u_secondary += lvl.dt * w2 * k
280282
else:
281-
self.level.uend = self.level.u[0].copy()
283+
lvl.uend = lvl.prob.dtype_u(lvl.u[0])
282284
if type(self.coll) == ButcherTableau:
283-
for w, k in zip(self.coll.weights, self.level.f[1:]):
284-
self.level.uend += self.level.dt * w * k
285+
for w, k in zip(self.coll.weights, lvl.f[1:]):
286+
lvl.uend += lvl.dt * w * k
285287
elif type(self.coll) == ButcherTableauEmbedded:
286-
self.u_secondary = self.level.u[0].copy()
287-
for w1, w2, k in zip(self.coll.weights[0], self.coll.weights[1], self.level.f[1:]):
288-
self.level.uend += self.level.dt * w1 * k
289-
self.u_secondary += self.level.dt * w2 * k
288+
self.u_secondary = lvl.prob.dtype_u(lvl.u[0])
289+
for w1, w2, k in zip(self.coll.weights[0], self.coll.weights[1], lvl.f[1:]):
290+
lvl.uend += lvl.dt * w1 * k
291+
self.u_secondary += lvl.dt * w2 * k
290292

291293
@property
292294
def level(self):
@@ -444,32 +446,34 @@ def compute_end_point(self):
444446
"""
445447
In this Runge-Kutta implementation, the solution to the step is always stored in the last node
446448
"""
447-
if self.level.f[1] is None:
448-
self.level.uend = self.level.u[0]
449+
lvl = self.level
450+
451+
if lvl.f[1] is None:
452+
lvl.uend = lvl.prob.dtype_u(lvl.u[0])
449453
if type(self.coll) == ButcherTableauEmbedded:
450-
self.u_secondary = self.level.u[0].copy()
454+
self.u_secondary = lvl.prob.dtype_u(lvl.u[0])
451455
elif self.coll.globally_stiffly_accurate and self.coll_explicit.globally_stiffly_accurate:
452-
self.level.uend = self.level.u[-1]
456+
lvl.uend = lvl.u[-1]
453457
if type(self.coll) == ButcherTableauEmbedded:
454-
self.u_secondary = self.level.u[0].copy()
455-
for w2, w2E, k in zip(self.coll.weights[1], self.coll_explicit.weights[1], self.level.f[1:]):
456-
self.u_secondary += self.level.dt * (w2 * k.impl + w2E * k.expl)
458+
self.u_secondary = lvl.prob.dtype_u(lvl.u[0])
459+
for w2, w2E, k in zip(self.coll.weights[1], self.coll_explicit.weights[1], lvl.f[1:]):
460+
self.u_secondary += lvl.dt * (w2 * k.impl + w2E * k.expl)
457461
else:
458-
self.level.uend = self.level.u[0].copy()
462+
lvl.uend = lvl.prob.dtype_u(lvl.u[0])
459463
if type(self.coll) == ButcherTableau:
460-
for w, wE, k in zip(self.coll.weights, self.coll_explicit.weights, self.level.f[1:]):
461-
self.level.uend += self.level.dt * (w * k.impl + wE * k.expl)
464+
for w, wE, k in zip(self.coll.weights, self.coll_explicit.weights, lvl.f[1:]):
465+
lvl.uend += lvl.dt * (w * k.impl + wE * k.expl)
462466
elif type(self.coll) == ButcherTableauEmbedded:
463-
self.u_secondary = self.level.u[0].copy()
467+
self.u_secondary = lvl.u[0].copy()
464468
for w1, w2, w1E, w2E, k in zip(
465469
self.coll.weights[0],
466470
self.coll.weights[1],
467471
self.coll_explicit.weights[0],
468472
self.coll_explicit.weights[1],
469-
self.level.f[1:],
473+
lvl.f[1:],
470474
):
471-
self.level.uend += self.level.dt * (w1 * k.impl + w1E * k.expl)
472-
self.u_secondary += self.level.dt * (w2 * k.impl + w2E * k.expl)
475+
lvl.uend += lvl.dt * (w1 * k.impl + w1E * k.expl)
476+
self.u_secondary += lvl.dt * (w2 * k.impl + w2E * k.expl)
473477

474478

475479
class ForwardEuler(RungeKutta):

pySDC/implementations/sweeper_classes/Runge_Kutta_Nystrom.py

Lines changed: 96 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,74 @@
44
from pySDC.core.sweeper import Sweeper, _Pars
55
from pySDC.core.errors import ParameterError
66
from pySDC.implementations.datatype_classes.particles import particles, fields, acceleration
7-
from pySDC.implementations.sweeper_classes.Runge_Kutta import ButcherTableau
8-
from copy import deepcopy
97
from pySDC.implementations.sweeper_classes.Runge_Kutta import RungeKutta
108

119

10+
class ButcherTableauNoCollUpdate(object):
11+
"""Version of Butcher Tableau that does not need a collocation update because the weights are put in the last line of Q"""
12+
13+
def __init__(self, weights, nodes, matrix):
14+
"""
15+
Initialization routine to get a quadrature matrix out of a Butcher tableau
16+
17+
Args:
18+
weights (numpy.ndarray): Butcher tableau weights
19+
nodes (numpy.ndarray): Butcher tableau nodes
20+
matrix (numpy.ndarray): Butcher tableau entries
21+
"""
22+
# check if the arguments have the correct form
23+
if type(matrix) != np.ndarray:
24+
raise ParameterError('Runge-Kutta matrix needs to be supplied as a numpy array!')
25+
elif len(np.unique(matrix.shape)) != 1 or len(matrix.shape) != 2:
26+
raise ParameterError('Runge-Kutta matrix needs to be a square 2D numpy array!')
27+
28+
if type(weights) != np.ndarray:
29+
raise ParameterError('Weights need to be supplied as a numpy array!')
30+
elif len(weights.shape) != 1:
31+
raise ParameterError(f'Incompatible dimension of weights! Need 1, got {len(weights.shape)}')
32+
elif len(weights) != matrix.shape[0]:
33+
raise ParameterError(f'Incompatible number of weights! Need {matrix.shape[0]}, got {len(weights)}')
34+
35+
if type(nodes) != np.ndarray:
36+
raise ParameterError('Nodes need to be supplied as a numpy array!')
37+
elif len(nodes.shape) != 1:
38+
raise ParameterError(f'Incompatible dimension of nodes! Need 1, got {len(nodes.shape)}')
39+
elif len(nodes) != matrix.shape[0]:
40+
raise ParameterError(f'Incompatible number of nodes! Need {matrix.shape[0]}, got {len(nodes)}')
41+
42+
self.globally_stiffly_accurate = np.allclose(matrix[-1], weights)
43+
44+
self.tleft = 0.0
45+
self.tright = 1.0
46+
self.num_solution_stages = 0 if self.globally_stiffly_accurate else 1
47+
self.num_nodes = matrix.shape[0] + self.num_solution_stages
48+
self.weights = weights
49+
50+
if self.globally_stiffly_accurate:
51+
# For globally stiffly accurate methods, the last row of the Butcher tableau is the same as the weights.
52+
self.nodes = np.append([0], nodes)
53+
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
54+
self.Qmat[1:, 1:] = matrix
55+
else:
56+
self.nodes = np.append(np.append([0], nodes), [1])
57+
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
58+
self.Qmat[1:-1, 1:-1] = matrix
59+
self.Qmat[-1, 1:-1] = weights # this is for computing the solution to the step from the previous stages
60+
61+
self.left_is_node = True
62+
self.right_is_node = self.nodes[-1] == self.tright
63+
64+
# compute distances between the nodes
65+
if self.num_nodes > 1:
66+
self.delta_m = self.nodes[1:] - self.nodes[:-1]
67+
else:
68+
self.delta_m = np.zeros(1)
69+
self.delta_m[0] = self.nodes[0] - self.tleft
70+
71+
# check if the RK scheme is implicit
72+
self.implicit = any(matrix[i, i] != 0 for i in range(self.num_nodes - self.num_solution_stages))
73+
74+
1275
class RungeKuttaNystrom(RungeKutta):
1376
"""
1477
Runge-Kutta scheme that fits the interface of a sweeper.
@@ -34,56 +97,31 @@ class RungeKuttaNystrom(RungeKutta):
3497
All of these variables are either determined by the RK rule, or are not part of an RK scheme.
3598
3699
Attribues:
37-
butcher_tableau (ButcherTableau): Butcher tableau for the Runge-Kutta scheme that you want
100+
butcher_tableau (ButcherTableauNoCollUpdate): Butcher tableau for the Runge-Kutta scheme that you want
38101
"""
39102

103+
ButcherTableauClass = ButcherTableauNoCollUpdate
104+
weights_bar = None
105+
matrix_bar = None
106+
40107
def __init__(self, params):
41108
"""
42109
Initialization routine for the custom sweeper
43110
44111
Args:
45112
params: parameters for the sweeper
46113
"""
47-
# set up logger
48-
self.logger = logging.getLogger('sweeper')
49-
50-
essential_keys = ['butcher_tableau']
51-
for key in essential_keys:
52-
if key not in params:
53-
msg = 'need %s to instantiate step, only got %s' % (key, str(params.keys()))
54-
self.logger.error(msg)
55-
raise ParameterError(msg)
56-
57-
# check if some parameters are set which only apply to actual sweepers
58-
for key in ['initial_guess', 'collocation_class', 'num_nodes']:
59-
if key in params:
60-
self.logger.warning(f'"{key}" will be ignored by Runge-Kutta sweeper')
61-
62-
# set parameters to their actual values
63-
params['initial_guess'] = 'zero'
64-
params['collocation_class'] = type(params['butcher_tableau'])
65-
params['num_nodes'] = params['butcher_tableau'].num_nodes
66-
67-
# disable residual computation by default
68-
params['skip_residual_computation'] = params.get(
69-
'skip_residual_computation', ('IT_CHECK', 'IT_FINE', 'IT_COARSE', 'IT_UP', 'IT_DOWN')
70-
)
71-
72-
self.params = _Pars(params)
73-
74-
self.coll = params['butcher_tableau']
75-
self.coll_bar = params['butcher_tableau_bar']
76-
77-
# This will be set as soon as the sweeper is instantiated at the level
78-
self.__level = None
79-
80-
self.parallelizable = False
81-
self.QI = self.coll.Qmat
114+
super().__init__(params)
115+
self.coll_bar = self.get_Butcher_tableau_bar()
82116
self.Qx = self.coll_bar.Qmat
83117

118+
@classmethod
119+
def get_Butcher_tableau_bar(cls):
120+
return cls.ButcherTableauClass(cls.weights_bar, cls.nodes, cls.matrix_bar)
121+
84122
def get_full_f(self, f):
85123
"""
86-
Test the right hand side funtion is the correct type
124+
Test the right hand side function is the correct type
87125
88126
Args:
89127
f (dtype_f): Right hand side at a single node
@@ -118,7 +156,7 @@ def update_nodes(self):
118156

119157
for m in range(0, M):
120158
# build rhs, consisting of the known values from above and new values from previous nodes (at k+1)
121-
rhs = deepcopy(L.u[0])
159+
rhs = P.dtype_u(L.u[0])
122160
rhs.pos += L.dt * self.coll.nodes[m + 1] * L.u[0].vel
123161

124162
for j in range(1, m + 1):
@@ -147,7 +185,7 @@ def update_nodes(self):
147185
if self.coll.implicit:
148186
# That is why it only works for the Velocity-Verlet scheme
149187
L.f[0] = P.eval_f(L.u[0], L.time)
150-
L.f[m + 1] = deepcopy(L.f[0])
188+
L.f[m + 1] = P.dtype_f(L.f[0])
151189
else:
152190
if m != self.coll.num_nodes - 1:
153191
L.f[m + 1] = P.eval_f(L.u[m + 1], L.time + L.dt * self.coll.nodes[m])
@@ -173,23 +211,18 @@ class RKN(RungeKuttaNystrom):
173211
Chapter: II.14 Numerical methods for Second order differential equations
174212
"""
175213

176-
def __init__(self, params):
177-
nodes = np.array([0.0, 0.5, 0.5, 1])
178-
weights = np.array([1.0, 2.0, 2.0, 1.0]) / 6.0
179-
matrix = np.zeros([4, 4])
180-
matrix[1, 0] = 0.5
181-
matrix[2, 1] = 0.5
182-
matrix[3, 2] = 1.0
183-
184-
weights_bar = np.array([1.0, 1.0, 1.0, 0]) / 6.0
185-
matrix_bar = np.zeros([4, 4])
186-
matrix_bar[1, 0] = 1 / 8
187-
matrix_bar[2, 0] = 1 / 8
188-
matrix_bar[3, 2] = 1 / 2
189-
params['butcher_tableau'] = ButcherTableau(weights, nodes, matrix)
190-
params['butcher_tableau_bar'] = ButcherTableau(weights_bar, nodes, matrix_bar)
214+
nodes = np.array([0.0, 0.5, 0.5, 1])
215+
weights = np.array([1.0, 2.0, 2.0, 1.0]) / 6.0
216+
matrix = np.zeros([4, 4])
217+
matrix[1, 0] = 0.5
218+
matrix[2, 1] = 0.5
219+
matrix[3, 2] = 1.0
191220

192-
super(RKN, self).__init__(params)
221+
weights_bar = np.array([1.0, 1.0, 1.0, 0]) / 6.0
222+
matrix_bar = np.zeros([4, 4])
223+
matrix_bar[1, 0] = 1 / 8
224+
matrix_bar[2, 0] = 1 / 8
225+
matrix_bar[3, 2] = 1 / 2
193226

194227

195228
class Velocity_Verlet(RungeKuttaNystrom):
@@ -198,15 +231,9 @@ class Velocity_Verlet(RungeKuttaNystrom):
198231
https://de.wikipedia.org/wiki/Verlet-Algorithmus
199232
"""
200233

201-
def __init__(self, params):
202-
nodes = np.array([1.0, 1.0])
203-
weights = np.array([1 / 2, 0])
204-
matrix = np.zeros([2, 2])
205-
matrix[1, 1] = 1
206-
weights_bar = np.array([1 / 2, 0])
207-
matrix_bar = np.zeros([2, 2])
208-
params['butcher_tableau'] = ButcherTableau(weights, nodes, matrix)
209-
params['butcher_tableau_bar'] = ButcherTableau(weights_bar, nodes, matrix_bar)
210-
params['Velocity_verlet'] = True
211-
212-
super(Velocity_Verlet, self).__init__(params)
234+
nodes = np.array([1.0, 1.0])
235+
weights = np.array([1 / 2, 0])
236+
matrix = np.zeros([2, 2])
237+
matrix[1, 1] = 1
238+
weights_bar = np.array([1 / 2, 0])
239+
matrix_bar = np.zeros([2, 2])

pySDC/tests/test_sweepers/test_Runge_Kutta_sweeper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,16 @@ def single_run(sweeper_name, dt, lambdas, use_RK_sweeper=True, Tend=None, useGPU
120120
sweeper = controller.MS[0].levels[0].sweep
121121
sweeper.QI = rk_sweeper.get_Q_matrix()
122122
sweeper.coll = rk_sweeper.get_Butcher_tableau()
123+
_compute_end_point = type(sweeper).compute_end_point
123124
type(sweeper).compute_end_point = rk_sweeper.compute_end_point
124125

125126
prob = controller.MS[0].levels[0].prob
126127
ic = prob.u_exact(0)
127128
u_end, stats = controller.run(ic, 0.0, 5 * dt if Tend is None else Tend)
128129

130+
if not use_RK_sweeper:
131+
type(sweeper).compute_end_point = _compute_end_point
132+
129133
return stats, ic, controller
130134

131135

0 commit comments

Comments
 (0)