Skip to content

Commit 10c4ba1

Browse files
committed
Globally stiffly accurate ARK methods
1 parent 2f84eda commit 10c4ba1

File tree

3 files changed

+143
-33
lines changed

3 files changed

+143
-33
lines changed

pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,13 @@ def estimate_embedded_error_serial(self, L):
9393
dtype_u: The embedded error estimate
9494
"""
9595
if self.params.sweeper_type == "RK":
96-
# lower order solution is stored in the second to last entry of L.u
97-
return abs(L.u[-2] - L.u[-1])
96+
if L.f[1] is None:
97+
return -1
98+
else:
99+
L.sweep.compute_end_point()
100+
return abs(L.uend - L.sweep.u_secondary)
98101
elif self.params.sweeper_type == "SDC":
99-
# order rises by one between sweeps, making this so ridiculously easy
102+
# order rises by one between sweeps
100103
return abs(L.uold[-1] - L.u[-1])
101104
elif self.params.sweeper_type == 'MPI':
102105
comm = L.sweep.comm

pySDC/implementations/sweeper_classes/Runge_Kutta.py

Lines changed: 111 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,12 @@ def __init__(self, weights, nodes, matrix):
4141

4242
self.tleft = 0.0
4343
self.tright = 1.0
44-
self.num_solution_stages = 0 if self.globally_stiffly_accurate else 1
45-
self.num_nodes = matrix.shape[0] + self.num_solution_stages
44+
self.num_nodes = matrix.shape[0]
4645
self.weights = weights
4746

48-
if self.globally_stiffly_accurate:
49-
# For globally stiffly accurate methods, the last row of the Butcher tableau is the same as the weights.
50-
self.nodes = np.append([0], nodes)
51-
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
52-
self.Qmat[1:, 1:] = matrix
53-
else:
54-
self.nodes = np.append(np.append([0], nodes), [1])
55-
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
56-
self.Qmat[1:-1, 1:-1] = matrix
57-
self.Qmat[-1, 1:-1] = weights # this is for computing the solution to the step from the previous stages
47+
self.nodes = np.append([0], nodes)
48+
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
49+
self.Qmat[1:, 1:] = matrix
5850

5951
self.left_is_node = True
6052
self.right_is_node = self.nodes[-1] == self.tright
@@ -67,7 +59,7 @@ def __init__(self, weights, nodes, matrix):
6759
self.delta_m[0] = self.nodes[0] - self.tleft
6860

6961
# check if the RK scheme is implicit
70-
self.implicit = any(matrix[i, i] != 0 for i in range(self.num_nodes - self.num_solution_stages))
62+
self.implicit = any(matrix[i, i] != 0 for i in range(self.num_nodes))
7163

7264

7365
class ButcherTableauEmbedded(object):
@@ -103,21 +95,20 @@ def __init__(self, weights, nodes, matrix):
10395
raise ParameterError(f'Incompatible number of nodes! Need {matrix.shape[0]}, got {len(nodes)}')
10496

10597
# Set number of nodes, left and right interval boundaries
106-
self.num_solution_stages = 2
107-
self.num_nodes = matrix.shape[0] + self.num_solution_stages
98+
self.num_nodes = matrix.shape[0]
10899
self.tleft = 0.0
109100
self.tright = 1.0
110101

111102
self.nodes = np.append(np.append([0], nodes), [1, 1])
112103
self.weights = weights
113104
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
114-
self.Qmat[1:-2, 1:-2] = matrix
115-
self.Qmat[-1, 1:-2] = weights[0] # this is for computing the higher order solution
116-
self.Qmat[-2, 1:-2] = weights[1] # this is for computing the lower order solution
105+
self.Qmat[1:, 1:] = matrix
117106

118107
self.left_is_node = True
119108
self.right_is_node = self.nodes[-1] == self.tright
120109

110+
self.globally_stiffly_accurate = np.allclose(matrix[-1], weights[0])
111+
121112
# compute distances between the nodes
122113
if self.num_nodes > 1:
123114
self.delta_m = self.nodes[1:] - self.nodes[:-1]
@@ -126,7 +117,7 @@ def __init__(self, weights, nodes, matrix):
126117
self.delta_m[0] = self.nodes[0] - self.tleft
127118

128119
# check if the RK scheme is implicit
129-
self.implicit = any(matrix[i, i] != 0 for i in range(self.num_nodes - self.num_solution_stages))
120+
self.implicit = any(matrix[i, i] != 0 for i in range(self.num_nodes))
130121

131122

132123
class RungeKutta(Sweeper):
@@ -292,8 +283,7 @@ def update_nodes(self):
292283
lvl.u[m + 1][:] = rhs[:]
293284

294285
# update function values (we don't usually need to evaluate the RHS at the solution of the step)
295-
if m < M - self.coll.num_solution_stages or self.params.eval_rhs_at_right_boundary:
296-
lvl.f[m + 1] = prob.eval_f(lvl.u[m + 1], lvl.time + lvl.dt * self.coll.nodes[m + 1])
286+
lvl.f[m + 1] = prob.eval_f(lvl.u[m + 1], lvl.time + lvl.dt * self.coll.nodes[m + 1])
297287

298288
# indicate presence of new values at this level
299289
lvl.status.updated = True
@@ -304,7 +294,22 @@ def compute_end_point(self):
304294
"""
305295
In this Runge-Kutta implementation, the solution to the step is always stored in the last node
306296
"""
307-
self.level.uend = self.level.u[-1]
297+
if self.coll.globally_stiffly_accurate:
298+
self.level.uend = self.level.u[-1]
299+
if type(self.coll) == ButcherTableauEmbedded:
300+
self.u_secondary = self.level.u[0].copy()
301+
for w2, k in zip(self.coll.weights[1], self.level.f[1:]):
302+
self.u_secondary += self.level.dt * w2 * k
303+
else:
304+
self.level.uend = self.level.u[0].copy()
305+
if type(self.coll) == ButcherTableau:
306+
for w, k in zip(self.coll.weights, self.level.f[1:]):
307+
self.level.uend += self.level.dt * w * k
308+
elif type(self.coll) == ButcherTableauEmbedded:
309+
self.u_secondary = self.level.u[0].copy()
310+
for w1, w2, k in zip(self.coll.weights[0], self.coll.weights[1], self.level.f[1:]):
311+
self.level.uend += self.level.dt * w1 * k
312+
self.u_secondary += self.level.dt * w2 * k
308313

309314
@property
310315
def level(self):
@@ -356,6 +361,7 @@ class RungeKuttaIMEX(RungeKutta):
356361
"""
357362

358363
matrix_explicit = None
364+
weights_explicit = None
359365
ButcherTableauClass_explicit = ButcherTableau
360366

361367
def __init__(self, params):
@@ -366,6 +372,7 @@ def __init__(self, params):
366372
params: parameters for the sweeper
367373
"""
368374
super().__init__(params)
375+
type(self).weights_explicit = self.weights if self.weights_explicit is None else self.weights_explicit
369376
self.coll_explicit = self.get_Butcher_tableau_explicit()
370377
self.QE = self.coll_explicit.Qmat
371378

@@ -388,7 +395,7 @@ def predict(self):
388395

389396
@classmethod
390397
def get_Butcher_tableau_explicit(cls):
391-
return cls.ButcherTableauClass_explicit(cls.weights, cls.nodes, cls.matrix_explicit)
398+
return cls.ButcherTableauClass_explicit(cls.weights_explicit, cls.nodes, cls.matrix_explicit)
392399

393400
def integrate(self):
394401
"""
@@ -448,15 +455,41 @@ def update_nodes(self):
448455
else:
449456
lvl.u[m + 1][:] = rhs[:]
450457

451-
# update function values (we don't usually need to evaluate the RHS at the solution of the step)
452-
if m < M - self.coll.num_solution_stages or self.params.eval_rhs_at_right_boundary:
453-
lvl.f[m + 1] = prob.eval_f(lvl.u[m + 1], lvl.time + lvl.dt * self.coll.nodes[m + 1])
458+
# update function values
459+
lvl.f[m + 1] = prob.eval_f(lvl.u[m + 1], lvl.time + lvl.dt * self.coll.nodes[m + 1])
454460

455461
# indicate presence of new values at this level
456462
lvl.status.updated = True
457463

458464
return None
459465

466+
def compute_end_point(self):
467+
"""
468+
In this Runge-Kutta implementation, the solution to the step is always stored in the last node
469+
"""
470+
if self.coll.globally_stiffly_accurate and self.coll_explicit.globally_stiffly_accurate:
471+
self.level.uend = self.level.u[-1]
472+
if type(self.coll) == ButcherTableauEmbedded:
473+
self.u_secondary = self.level.u[0].copy()
474+
for w2, w2E, k in zip(self.coll.weights[1], self.coll_explicit.weights[1], self.level.f[1:]):
475+
self.u_secondary += self.level.dt * (w2 * k.impl + w2E * k.expl)
476+
else:
477+
self.level.uend = self.level.u[0].copy()
478+
if type(self.coll) == ButcherTableau:
479+
for w, wE, k in zip(self.coll.weights, self.coll_explicit.weights, self.level.f[1:]):
480+
self.level.uend += self.level.dt * (w * k.impl + wE * k.expl)
481+
elif type(self.coll) == ButcherTableauEmbedded:
482+
self.u_secondary = self.level.u[0].copy()
483+
for w1, w2, w1E, w2E, k in zip(
484+
self.coll.weights[0],
485+
self.coll.weights[1],
486+
self.coll_explicit.weights[0],
487+
self.coll_explicit.weights[1],
488+
self.level.f[1:],
489+
):
490+
self.level.uend += self.level.dt * (w1 * k.impl + w1E * k.expl)
491+
self.u_secondary += self.level.dt * (w2 * k.impl + w2E * k.expl)
492+
460493

461494
class ForwardEuler(RungeKutta):
462495
"""
@@ -480,6 +513,14 @@ class BackwardEuler(RungeKutta):
480513
nodes, weights, matrix = generator.genCoeffs()
481514

482515

516+
class IMEXEuler(RungeKuttaIMEX):
517+
nodes = BackwardEuler.nodes
518+
weights = BackwardEuler.weights
519+
520+
matrix = BackwardEuler.matrix
521+
matrix_explicit = ForwardEuler.matrix
522+
523+
483524
class CrankNicolson(RungeKutta):
484525
"""
485526
Implicit Runge-Kutta method of second order, A-stable.
@@ -521,8 +562,13 @@ class Heun_Euler(RungeKutta):
521562
Second order explicit embedded Runge-Kutta method.
522563
"""
523564

565+
ButcherTableauClass = ButcherTableauEmbedded
566+
524567
generator = RK_SCHEMES["HEUN"]()
525-
nodes, weights, matrix = generator.genCoeffs()
568+
nodes, _weights, matrix = generator.genCoeffs()
569+
weights = np.zeros((2, len(_weights)))
570+
weights[0] = _weights
571+
weights[1] = matrix[-1]
526572

527573
@classmethod
528574
def get_update_order(cls):
@@ -697,3 +743,41 @@ class ARK548L2SA(RungeKuttaIMEX):
697743
@classmethod
698744
def get_update_order(cls):
699745
return 5
746+
747+
748+
class ARK324L2SAERK(RungeKutta):
749+
generator = RK_SCHEMES["ARK324L2SAERK"]()
750+
nodes, weights, matrix = generator.genCoeffs(embedded=True)
751+
ButcherTableauClass = ButcherTableauEmbedded
752+
753+
@classmethod
754+
def get_update_order(cls):
755+
return 3
756+
757+
758+
class ARK324L2SAESDIRK(ARK324L2SAERK):
759+
generator = RK_SCHEMES["ARK324L2SAESDIRK"]()
760+
matrix = generator.Q
761+
762+
763+
class ARK32(RungeKuttaIMEX):
764+
ButcherTableauClass = ButcherTableauEmbedded
765+
ButcherTableauClass_explicit = ButcherTableauEmbedded
766+
767+
nodes = ARK324L2SAESDIRK.nodes
768+
weights = ARK324L2SAESDIRK.weights
769+
770+
matrix = ARK324L2SAESDIRK.matrix
771+
matrix_explicit = ARK324L2SAERK.matrix
772+
773+
@classmethod
774+
def get_update_order(cls):
775+
return 3
776+
777+
778+
class ARK2(RungeKuttaIMEX):
779+
generator_IMP = RK_SCHEMES["ARK222EDIRK"]()
780+
generator_EXP = RK_SCHEMES["ARK222ERK"]()
781+
782+
nodes, weights, matrix = generator_IMP.genCoeffs()
783+
_, weights_explicit, matrix_explicit = generator_EXP.genCoeffs()

pySDC/tests/test_sweepers/test_Runge_Kutta_sweeper.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@
1818
'ARK548L2SAERK',
1919
'ARK548L2SAESDIRK2',
2020
'ARK548L2SAERK2',
21+
'ARK324L2SAERK',
22+
'ARK324L2SAESDIRK',
2123
]
2224
IMEX_SWEEPERS = [
2325
'ARK54',
2426
'ARK548L2SA',
27+
'IMEXEuler',
28+
'ARK32',
29+
'ARK2',
2530
]
2631

2732

@@ -61,6 +66,7 @@ def single_run(sweeper_name, dt, lambdas, use_RK_sweeper=True, Tend=None, useGPU
6166
from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
6267
from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedError
6368
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
69+
from pySDC.implementations.sweeper_classes.Runge_Kutta import ButcherTableauEmbedded
6470

6571
level_params = {'dt': dt}
6672

@@ -93,21 +99,28 @@ def single_run(sweeper_name, dt, lambdas, use_RK_sweeper=True, Tend=None, useGPU
9399
'problem_class': problem_class,
94100
'sweeper_params': sweeper_params,
95101
'problem_params': problem_params,
96-
'convergence_controllers': {EstimateEmbeddedError: {}},
102+
'convergence_controllers': {},
97103
}
98104

99105
controller_params = {
100106
'logger_level': 40,
101107
'hook_class': [LogWork, LogGlobalErrorPostRun, LogSolution, LogEmbeddedErrorEstimate],
102108
}
103109

110+
if (
111+
hasattr(description['sweeper_class'], 'ButcherTableauClass')
112+
and description['sweeper_class'].ButcherTableauClass == ButcherTableauEmbedded
113+
):
114+
description['convergence_controllers'][EstimateEmbeddedError] = {}
115+
104116
controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
105117

106118
if not use_RK_sweeper:
107119
rk_sweeper = get_sweeper(sweeper_name)
108120
sweeper = controller.MS[0].levels[0].sweep
109121
sweeper.QI = rk_sweeper.get_Q_matrix()
110122
sweeper.coll = rk_sweeper.get_Butcher_tableau()
123+
type(sweeper).compute_end_point = rk_sweeper.compute_end_point
111124

112125
prob = controller.MS[0].levels[0].prob
113126
ic = prob.u_exact(0)
@@ -149,6 +162,11 @@ def test_order(sweeper_name, useGPU=False):
149162
'ARK548L2SAESDIRK2': 6,
150163
'ARK54': 6,
151164
'ARK548L2SA': 6,
165+
'IMEXEuler': 2,
166+
'ARK2': 3,
167+
'ARK32': 4,
168+
'ARK324L2SAERK': 4,
169+
'ARK324L2SAESDIRK': 4,
152170
}
153171

154172
dt_max = {
@@ -160,6 +178,7 @@ def test_order(sweeper_name, useGPU=False):
160178
'ARK548L2SAERK2': 1e0,
161179
'ARK54': 5e-2,
162180
'ARK548L2SA': 5e-2,
181+
'IMEXEuler': 1e-2,
163182
}
164183

165184
lambdas = [[-1.0e-1 + 0j]]
@@ -233,6 +252,8 @@ def test_stability(sweeper_name, useGPU=False):
233252
'ARK548L2SAERK': False,
234253
'ARK548L2SAESDIRK2': True,
235254
'ARK548L2SAERK2': False,
255+
'ARK324L2SAERK': False,
256+
'ARK324L2SAESDIRK': True,
236257
}
237258

238259
re = -np.logspace(-3, 2, 50)
@@ -271,7 +292,7 @@ def test_rhs_evals(sweeper_name, useGPU=False):
271292
stats, _, controller = single_run(sweeper_name, 1.0, lambdas, Tend=10.0, useGPU=useGPU)
272293

273294
sweep = controller.MS[0].levels[0].sweep
274-
num_stages = sweep.coll.num_nodes - sweep.coll.num_solution_stages
295+
num_stages = sweep.coll.num_nodes
275296

276297
rhs_evaluations = [me[1] for me in get_sorted(stats, type='work_rhs')]
277298

@@ -357,5 +378,7 @@ def test_RK_sweepers_with_GPU(test_name, sweeper_name):
357378

358379
if __name__ == '__main__':
359380
# test_rhs_evals('ARK54')
360-
test_order('CrankNicolson')
381+
test_order('ARK2')
382+
# test_order('ARK54')
383+
# test_sweeper_equivalence('Cash_Karp')
361384
# test_order('ARK54')

0 commit comments

Comments
 (0)