Skip to content

Commit 31a65ac

Browse files
Refactor of RK implementation (#499)
* Globally stiffly accurate ARK methods * Refactor and fix * Fixes and refactor for RKN
1 parent 2f84eda commit 31a65ac

File tree

4 files changed

+280
-157
lines changed

4 files changed

+280
-157
lines changed

pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ def estimate_embedded_error_serial(self, L):
9393
dtype_u: The embedded error estimate
9494
"""
9595
if self.params.sweeper_type == "RK":
96-
# lower order solution is stored in the second to last entry of L.u
97-
return abs(L.u[-2] - L.u[-1])
96+
L.sweep.compute_end_point()
97+
return abs(L.uend - L.sweep.u_secondary)
9898
elif self.params.sweeper_type == "SDC":
99-
# order rises by one between sweeps, making this so ridiculously easy
99+
# order rises by one between sweeps
100100
return abs(L.uold[-1] - L.u[-1])
101101
elif self.params.sweeper_type == 'MPI':
102102
comm = L.sweep.comm

pySDC/implementations/sweeper_classes/Runge_Kutta.py

Lines changed: 151 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -17,44 +17,16 @@ def __init__(self, weights, nodes, matrix):
1717
nodes (numpy.ndarray): Butcher tableau nodes
1818
matrix (numpy.ndarray): Butcher tableau entries
1919
"""
20-
# check if the arguments have the correct form
21-
if type(matrix) != np.ndarray:
22-
raise ParameterError('Runge-Kutta matrix needs to be supplied as a numpy array!')
23-
elif len(np.unique(matrix.shape)) != 1 or len(matrix.shape) != 2:
24-
raise ParameterError('Runge-Kutta matrix needs to be a square 2D numpy array!')
25-
26-
if type(weights) != np.ndarray:
27-
raise ParameterError('Weights need to be supplied as a numpy array!')
28-
elif len(weights.shape) != 1:
29-
raise ParameterError(f'Incompatible dimension of weights! Need 1, got {len(weights.shape)}')
30-
elif len(weights) != matrix.shape[0]:
31-
raise ParameterError(f'Incompatible number of weights! Need {matrix.shape[0]}, got {len(weights)}')
32-
33-
if type(nodes) != np.ndarray:
34-
raise ParameterError('Nodes need to be supplied as a numpy array!')
35-
elif len(nodes.shape) != 1:
36-
raise ParameterError(f'Incompatible dimension of nodes! Need 1, got {len(nodes.shape)}')
37-
elif len(nodes) != matrix.shape[0]:
38-
raise ParameterError(f'Incompatible number of nodes! Need {matrix.shape[0]}, got {len(nodes)}')
39-
40-
self.globally_stiffly_accurate = np.allclose(matrix[-1], weights)
20+
self.check_method(weights, nodes, matrix)
4121

4222
self.tleft = 0.0
4323
self.tright = 1.0
44-
self.num_solution_stages = 0 if self.globally_stiffly_accurate else 1
45-
self.num_nodes = matrix.shape[0] + self.num_solution_stages
24+
self.num_nodes = matrix.shape[0]
4625
self.weights = weights
4726

48-
if self.globally_stiffly_accurate:
49-
# For globally stiffly accurate methods, the last row of the Butcher tableau is the same as the weights.
50-
self.nodes = np.append([0], nodes)
51-
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
52-
self.Qmat[1:, 1:] = matrix
53-
else:
54-
self.nodes = np.append(np.append([0], nodes), [1])
55-
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
56-
self.Qmat[1:-1, 1:-1] = matrix
57-
self.Qmat[-1, 1:-1] = weights # this is for computing the solution to the step from the previous stages
27+
self.nodes = np.append([0], nodes)
28+
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
29+
self.Qmat[1:, 1:] = matrix
5830

5931
self.left_is_node = True
6032
self.right_is_node = self.nodes[-1] == self.tright
@@ -67,66 +39,58 @@ def __init__(self, weights, nodes, matrix):
6739
self.delta_m[0] = self.nodes[0] - self.tleft
6840

6941
# check if the RK scheme is implicit
70-
self.implicit = any(matrix[i, i] != 0 for i in range(self.num_nodes - self.num_solution_stages))
42+
self.implicit = any(matrix[i, i] != 0 for i in range(self.num_nodes))
7143

72-
73-
class ButcherTableauEmbedded(object):
74-
def __init__(self, weights, nodes, matrix):
44+
def check_method(self, weights, nodes, matrix):
7545
"""
76-
Initialization routine to get a quadrature matrix out of a Butcher tableau for embedded RK methods.
77-
78-
Be aware that the method that generates the final solution should be in the first row of the weights matrix.
79-
80-
Args:
81-
weights (numpy.ndarray): Butcher tableau weights
82-
nodes (numpy.ndarray): Butcher tableau nodes
83-
matrix (numpy.ndarray): Butcher tableau entries
46+
Check that the method is entered in the correct format
8447
"""
85-
# check if the arguments have the correct form
8648
if type(matrix) != np.ndarray:
8749
raise ParameterError('Runge-Kutta matrix needs to be supplied as a numpy array!')
8850
elif len(np.unique(matrix.shape)) != 1 or len(matrix.shape) != 2:
8951
raise ParameterError('Runge-Kutta matrix needs to be a square 2D numpy array!')
9052

91-
if type(weights) != np.ndarray:
92-
raise ParameterError('Weights need to be supplied as a numpy array!')
93-
elif len(weights.shape) != 2:
94-
raise ParameterError(f'Incompatible dimension of weights! Need 2, got {len(weights.shape)}')
95-
elif len(weights[0]) != matrix.shape[0]:
96-
raise ParameterError(f'Incompatible number of weights! Need {matrix.shape[0]}, got {len(weights[0])}')
97-
9853
if type(nodes) != np.ndarray:
9954
raise ParameterError('Nodes need to be supplied as a numpy array!')
10055
elif len(nodes.shape) != 1:
10156
raise ParameterError(f'Incompatible dimension of nodes! Need 1, got {len(nodes.shape)}')
10257
elif len(nodes) != matrix.shape[0]:
10358
raise ParameterError(f'Incompatible number of nodes! Need {matrix.shape[0]}, got {len(nodes)}')
10459

105-
# Set number of nodes, left and right interval boundaries
106-
self.num_solution_stages = 2
107-
self.num_nodes = matrix.shape[0] + self.num_solution_stages
108-
self.tleft = 0.0
109-
self.tright = 1.0
60+
self.check_weights(weights, nodes, matrix)
11061

111-
self.nodes = np.append(np.append([0], nodes), [1, 1])
112-
self.weights = weights
113-
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
114-
self.Qmat[1:-2, 1:-2] = matrix
115-
self.Qmat[-1, 1:-2] = weights[0] # this is for computing the higher order solution
116-
self.Qmat[-2, 1:-2] = weights[1] # this is for computing the lower order solution
62+
def check_weights(self, weights, nodes, matrix):
63+
"""
64+
Check that the weights of the method are entered in the correct format
65+
"""
66+
if type(weights) != np.ndarray:
67+
raise ParameterError('Weights need to be supplied as a numpy array!')
68+
elif len(weights.shape) != 1:
69+
raise ParameterError(f'Incompatible dimension of weights! Need 1, got {len(weights.shape)}')
70+
elif len(weights) != matrix.shape[0]:
71+
raise ParameterError(f'Incompatible number of weights! Need {matrix.shape[0]}, got {len(weights)}')
11772

118-
self.left_is_node = True
119-
self.right_is_node = self.nodes[-1] == self.tright
73+
@property
74+
def globally_stiffly_accurate(self):
75+
return np.allclose(self.Qmat[-1, 1:], self.weights)
12076

121-
# compute distances between the nodes
122-
if self.num_nodes > 1:
123-
self.delta_m = self.nodes[1:] - self.nodes[:-1]
124-
else:
125-
self.delta_m = np.zeros(1)
126-
self.delta_m[0] = self.nodes[0] - self.tleft
12777

128-
# check if the RK scheme is implicit
129-
self.implicit = any(matrix[i, i] != 0 for i in range(self.num_nodes - self.num_solution_stages))
78+
class ButcherTableauEmbedded(ButcherTableau):
79+
80+
def check_weights(self, weights, nodes, matrix):
81+
"""
82+
Check that the weights of the method are entered in the correct format
83+
"""
84+
if type(weights) != np.ndarray:
85+
raise ParameterError('Weights need to be supplied as a numpy array!')
86+
elif len(weights.shape) != 2:
87+
raise ParameterError(f'Incompatible dimension of weights! Need 2, got {len(weights.shape)}')
88+
elif len(weights[0]) != matrix.shape[0]:
89+
raise ParameterError(f'Incompatible number of weights! Need {matrix.shape[0]}, got {len(weights[0])}')
90+
91+
@property
92+
def globally_stiffly_accurate(self):
93+
return np.allclose(self.Qmat[-1, 1:], self.weights[0])
13094

13195

13296
class RungeKutta(Sweeper):
@@ -292,8 +256,7 @@ def update_nodes(self):
292256
lvl.u[m + 1][:] = rhs[:]
293257

294258
# update function values (we don't usually need to evaluate the RHS at the solution of the step)
295-
if m < M - self.coll.num_solution_stages or self.params.eval_rhs_at_right_boundary:
296-
lvl.f[m + 1] = prob.eval_f(lvl.u[m + 1], lvl.time + lvl.dt * self.coll.nodes[m + 1])
259+
lvl.f[m + 1] = prob.eval_f(lvl.u[m + 1], lvl.time + lvl.dt * self.coll.nodes[m + 1])
297260

298261
# indicate presence of new values at this level
299262
lvl.status.updated = True
@@ -304,7 +267,28 @@ def compute_end_point(self):
304267
"""
305268
In this Runge-Kutta implementation, the solution to the step is always stored in the last node
306269
"""
307-
self.level.uend = self.level.u[-1]
270+
lvl = self.level
271+
272+
if lvl.f[1] is None:
273+
lvl.uend = lvl.prob.dtype_u(lvl.u[0])
274+
if type(self.coll) == ButcherTableauEmbedded:
275+
self.u_secondary = lvl.prob.dtype_u(lvl.u[0])
276+
elif self.coll.globally_stiffly_accurate:
277+
lvl.uend = lvl.prob.dtype_u(lvl.u[-1])
278+
if type(self.coll) == ButcherTableauEmbedded:
279+
self.u_secondary = lvl.prob.dtype_u(lvl.u[0])
280+
for w2, k in zip(self.coll.weights[1], lvl.f[1:]):
281+
self.u_secondary += lvl.dt * w2 * k
282+
else:
283+
lvl.uend = lvl.prob.dtype_u(lvl.u[0])
284+
if type(self.coll) == ButcherTableau:
285+
for w, k in zip(self.coll.weights, lvl.f[1:]):
286+
lvl.uend += lvl.dt * w * k
287+
elif type(self.coll) == ButcherTableauEmbedded:
288+
self.u_secondary = lvl.prob.dtype_u(lvl.u[0])
289+
for w1, w2, k in zip(self.coll.weights[0], self.coll.weights[1], lvl.f[1:]):
290+
lvl.uend += lvl.dt * w1 * k
291+
self.u_secondary += lvl.dt * w2 * k
308292

309293
@property
310294
def level(self):
@@ -356,6 +340,7 @@ class RungeKuttaIMEX(RungeKutta):
356340
"""
357341

358342
matrix_explicit = None
343+
weights_explicit = None
359344
ButcherTableauClass_explicit = ButcherTableau
360345

361346
def __init__(self, params):
@@ -366,6 +351,7 @@ def __init__(self, params):
366351
params: parameters for the sweeper
367352
"""
368353
super().__init__(params)
354+
type(self).weights_explicit = self.weights if self.weights_explicit is None else self.weights_explicit
369355
self.coll_explicit = self.get_Butcher_tableau_explicit()
370356
self.QE = self.coll_explicit.Qmat
371357

@@ -388,7 +374,7 @@ def predict(self):
388374

389375
@classmethod
390376
def get_Butcher_tableau_explicit(cls):
391-
return cls.ButcherTableauClass_explicit(cls.weights, cls.nodes, cls.matrix_explicit)
377+
return cls.ButcherTableauClass_explicit(cls.weights_explicit, cls.nodes, cls.matrix_explicit)
392378

393379
def integrate(self):
394380
"""
@@ -448,15 +434,47 @@ def update_nodes(self):
448434
else:
449435
lvl.u[m + 1][:] = rhs[:]
450436

451-
# update function values (we don't usually need to evaluate the RHS at the solution of the step)
452-
if m < M - self.coll.num_solution_stages or self.params.eval_rhs_at_right_boundary:
453-
lvl.f[m + 1] = prob.eval_f(lvl.u[m + 1], lvl.time + lvl.dt * self.coll.nodes[m + 1])
437+
# update function values
438+
lvl.f[m + 1] = prob.eval_f(lvl.u[m + 1], lvl.time + lvl.dt * self.coll.nodes[m + 1])
454439

455440
# indicate presence of new values at this level
456441
lvl.status.updated = True
457442

458443
return None
459444

445+
def compute_end_point(self):
446+
"""
447+
In this Runge-Kutta implementation, the solution to the step is always stored in the last node
448+
"""
449+
lvl = self.level
450+
451+
if lvl.f[1] is None:
452+
lvl.uend = lvl.prob.dtype_u(lvl.u[0])
453+
if type(self.coll) == ButcherTableauEmbedded:
454+
self.u_secondary = lvl.prob.dtype_u(lvl.u[0])
455+
elif self.coll.globally_stiffly_accurate and self.coll_explicit.globally_stiffly_accurate:
456+
lvl.uend = lvl.u[-1]
457+
if type(self.coll) == ButcherTableauEmbedded:
458+
self.u_secondary = lvl.prob.dtype_u(lvl.u[0])
459+
for w2, w2E, k in zip(self.coll.weights[1], self.coll_explicit.weights[1], lvl.f[1:]):
460+
self.u_secondary += lvl.dt * (w2 * k.impl + w2E * k.expl)
461+
else:
462+
lvl.uend = lvl.prob.dtype_u(lvl.u[0])
463+
if type(self.coll) == ButcherTableau:
464+
for w, wE, k in zip(self.coll.weights, self.coll_explicit.weights, lvl.f[1:]):
465+
lvl.uend += lvl.dt * (w * k.impl + wE * k.expl)
466+
elif type(self.coll) == ButcherTableauEmbedded:
467+
self.u_secondary = lvl.u[0].copy()
468+
for w1, w2, w1E, w2E, k in zip(
469+
self.coll.weights[0],
470+
self.coll.weights[1],
471+
self.coll_explicit.weights[0],
472+
self.coll_explicit.weights[1],
473+
lvl.f[1:],
474+
):
475+
lvl.uend += lvl.dt * (w1 * k.impl + w1E * k.expl)
476+
self.u_secondary += lvl.dt * (w2 * k.impl + w2E * k.expl)
477+
460478

461479
class ForwardEuler(RungeKutta):
462480
"""
@@ -480,6 +498,14 @@ class BackwardEuler(RungeKutta):
480498
nodes, weights, matrix = generator.genCoeffs()
481499

482500

501+
class IMEXEuler(RungeKuttaIMEX):
502+
nodes = BackwardEuler.nodes
503+
weights = BackwardEuler.weights
504+
505+
matrix = BackwardEuler.matrix
506+
matrix_explicit = ForwardEuler.matrix
507+
508+
483509
class CrankNicolson(RungeKutta):
484510
"""
485511
Implicit Runge-Kutta method of second order, A-stable.
@@ -521,8 +547,13 @@ class Heun_Euler(RungeKutta):
521547
Second order explicit embedded Runge-Kutta method.
522548
"""
523549

550+
ButcherTableauClass = ButcherTableauEmbedded
551+
524552
generator = RK_SCHEMES["HEUN"]()
525-
nodes, weights, matrix = generator.genCoeffs()
553+
nodes, _weights, matrix = generator.genCoeffs()
554+
weights = np.zeros((2, len(_weights)))
555+
weights[0] = _weights
556+
weights[1] = matrix[-1]
526557

527558
@classmethod
528559
def get_update_order(cls):
@@ -697,3 +728,41 @@ class ARK548L2SA(RungeKuttaIMEX):
697728
@classmethod
698729
def get_update_order(cls):
699730
return 5
731+
732+
733+
class ARK324L2SAERK(RungeKutta):
734+
generator = RK_SCHEMES["ARK324L2SAERK"]()
735+
nodes, weights, matrix = generator.genCoeffs(embedded=True)
736+
ButcherTableauClass = ButcherTableauEmbedded
737+
738+
@classmethod
739+
def get_update_order(cls):
740+
return 3
741+
742+
743+
class ARK324L2SAESDIRK(ARK324L2SAERK):
744+
generator = RK_SCHEMES["ARK324L2SAESDIRK"]()
745+
matrix = generator.Q
746+
747+
748+
class ARK32(RungeKuttaIMEX):
749+
ButcherTableauClass = ButcherTableauEmbedded
750+
ButcherTableauClass_explicit = ButcherTableauEmbedded
751+
752+
nodes = ARK324L2SAESDIRK.nodes
753+
weights = ARK324L2SAESDIRK.weights
754+
755+
matrix = ARK324L2SAESDIRK.matrix
756+
matrix_explicit = ARK324L2SAERK.matrix
757+
758+
@classmethod
759+
def get_update_order(cls):
760+
return 3
761+
762+
763+
class ARK2(RungeKuttaIMEX):
764+
generator_IMP = RK_SCHEMES["ARK222EDIRK"]()
765+
generator_EXP = RK_SCHEMES["ARK222ERK"]()
766+
767+
nodes, weights, matrix = generator_IMP.genCoeffs()
768+
_, weights_explicit, matrix_explicit = generator_EXP.genCoeffs()

0 commit comments

Comments
 (0)