Skip to content

Commit 7e7e73f

Browse files
authored
Runge-Kutta sweeper for DAEs (#432)
* Added DIRK43_2 and EDIRK4 methods * Runge-Kutta sweeper for DAEs * Every problem needs a du_exact method when using RK-DAE sweepers * Tests for RK methods * Added label in documentation * Typos + some corrections * First requested changes * Added statement: predict only with du_exact in first step * Make function for implicit system to be solved as static method in FI-SDC and SI-SDC sweepers * Removed deep copy
1 parent d89cac1 commit 7e7e73f

File tree

9 files changed

+700
-56
lines changed

9 files changed

+700
-56
lines changed

pySDC/implementations/sweeper_classes/Runge_Kutta.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,38 @@ def get_update_order(cls):
590590
return 4
591591

592592

593+
class DIRK43_2(RungeKutta):
594+
"""
595+
L-stable Diagonally Implicit RK method with four stages of order 3.
596+
Taken from [here](https://en.wikipedia.org/wiki/List_of_Runge%E2%80%93Kutta_methods).
597+
"""
598+
599+
nodes = np.array([0.5, 2.0 / 3.0, 0.5, 1.0])
600+
weights = np.array([3.0 / 2.0, -3.0 / 2.0, 0.5, 0.5])
601+
matrix = np.zeros((4, 4))
602+
matrix[0, 0] = 0.5
603+
matrix[1, :2] = [1.0 / 6.0, 0.5]
604+
matrix[2, :3] = [-0.5, 0.5, 0.5]
605+
matrix[3, :] = [3.0 / 2.0, -3.0 / 2.0, 0.5, 0.5]
606+
ButcherTableauClass = ButcherTableau
607+
608+
609+
class EDIRK4(RungeKutta):
610+
"""
611+
Stiffly accurate, fourth-order EDIRK with four stages. Taken from
612+
[here](https://ntrs.nasa.gov/citations/20160005923), second one in eq. (216).
613+
"""
614+
615+
nodes = np.array([0.0, 3.0 / 2.0, 7.0 / 5.0, 1.0])
616+
weights = np.array([13.0, 84.0, -125.0, 70.0]) / 42.0
617+
matrix = np.zeros((4, 4))
618+
matrix[0, 0] = 0
619+
matrix[1, :2] = [3.0 / 4.0, 3.0 / 4.0]
620+
matrix[2, :3] = [447.0 / 675.0, -357.0 / 675.0, 855.0 / 675.0]
621+
matrix[3, :] = [13.0 / 42.0, 84.0 / 42.0, -125.0 / 42.0, 70.0 / 42.0]
622+
ButcherTableauClass = ButcherTableau
623+
624+
593625
class ESDIRK53(RungeKutta):
594626
"""
595627
A-stable embedded RK pair of orders 5 and 3, ESDIRK5(3)6L[2]SA.

pySDC/projects/DAE/misc/ProblemDAE.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,19 @@ def __init__(self, nvars, newton_tol):
3636
self.work_counters['newton'] = WorkCounter()
3737
self.work_counters['rhs'] = WorkCounter()
3838

39-
def solve_system(self, impl_sys, u0, t):
39+
def solve_system(self, impl_sys, u_approx, factor, u0, t):
4040
r"""
4141
Solver for nonlinear implicit system (defined in sweeper).
4242
4343
Parameters
4444
----------
4545
impl_sys : callable
4646
Implicit system to be solved.
47+
u_approx : dtype_u
48+
Approximation of solution :math:`u` which is needed to solve
49+
the implicit system.
50+
factor : float
51+
Abbrev. for the node-to-node stepsize.
4752
u0 : dtype_u
4853
Initial guess for solver.
4954
t : float
@@ -57,7 +62,7 @@ def solve_system(self, impl_sys, u0, t):
5762
me = self.dtype_u(self.init)
5863

5964
def implSysFlatten(unknowns, **kwargs):
60-
sys = impl_sys(unknowns.reshape(me.shape).view(type(u0)), **kwargs)
65+
sys = impl_sys(unknowns.reshape(me.shape).view(type(u0)), self, factor, u_approx, t, **kwargs)
6166
return sys.flatten()
6267

6368
opt = root(
@@ -69,3 +74,21 @@ def implSysFlatten(unknowns, **kwargs):
6974
me[:] = opt.x.reshape(me.shape)
7075
self.work_counters['newton'].niter += opt.nfev
7176
return me
77+
78+
def du_exact(self, t):
79+
r"""
80+
Routine for the derivative of the exact solution at time :math:`t \leq 1`.
81+
For this problem, the exact solution is piecewise.
82+
83+
Parameters
84+
----------
85+
t : float
86+
Time of the exact solution.
87+
88+
Returns
89+
-------
90+
me : dtype_u
91+
Derivative of exact solution.
92+
"""
93+
94+
raise NotImplementedError('ERROR: problem has to implement du_exact(self, t)!')

pySDC/projects/DAE/problems/DiscontinuousTestDAE.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,33 @@ def u_exact(self, t, **kwargs):
128128
me.alg[0] = np.sinh(self.t_switch_exact)
129129
return me
130130

131+
def du_exact(self, t, **kwargs):
132+
r"""
133+
Routine for the derivative of the exact solution at time :math:`t \leq 1`.
134+
For this problem, the exact solution is piecewise.
135+
136+
Parameters
137+
----------
138+
t : float
139+
Time of the exact solution.
140+
141+
Returns
142+
-------
143+
me : dtype_u
144+
Derivative of exact solution.
145+
"""
146+
147+
assert t >= 1, 'ERROR: u_exact only available for t>=1'
148+
149+
me = self.dtype_u(self.init)
150+
if t <= self.t_switch_exact:
151+
me.diff[0] = np.sinh(t)
152+
me.alg[0] = np.cosh(t)
153+
else:
154+
me.diff[0] = np.sinh(self.t_switch_exact)
155+
me.alg[0] = np.cosh(self.t_switch_exact)
156+
return me
157+
131158
def get_switching_info(self, u, t):
132159
r"""
133160
Provides information about the state function of the problem. A change in sign of the state function

pySDC/projects/DAE/problems/simple_DAE.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,26 @@ def u_exact(self, t):
210210
me.alg[0] = -np.exp(t) / (2 - t)
211211
return me
212212

213+
def du_exact(self, t):
214+
"""
215+
Routine for the derivative of the exact solution.
216+
217+
Parameters
218+
----------
219+
t : float
220+
The time of the reference solution.
221+
222+
Returns
223+
-------
224+
me : dtype_u
225+
The reference solution as mesh object containing three components.
226+
"""
227+
228+
me = self.dtype_u(self.init)
229+
me.diff[:2] = (np.exp(t), np.exp(t))
230+
me.alg[0] = (np.exp(t) * (t - 3)) / ((2 - t) ** 2)
231+
return me
232+
213233

214234
class problematic_f(ptype_dae):
215235
r"""
@@ -293,3 +313,22 @@ def u_exact(self, t):
293313
me = self.dtype_u(self.init)
294314
me[:] = (np.sin(t), 0)
295315
return me
316+
317+
def du_exact(self, t):
318+
"""
319+
Routine for the derivative of the exact solution.
320+
321+
Parameters
322+
----------
323+
t : float
324+
The time of the reference solution.
325+
326+
Returns
327+
-------
328+
me : dtype_u
329+
The reference solution as mesh object containing two components.
330+
"""
331+
332+
me = self.dtype_u(self.init)
333+
me[:] = (np.cos(t), 0)
334+
return me
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
from pySDC.projects.DAE.sweepers.fully_implicit_DAE import fully_implicit_DAE
2+
from pySDC.implementations.sweeper_classes.Runge_Kutta import (
3+
RungeKutta,
4+
BackwardEuler,
5+
CrankNicholson,
6+
EDIRK4,
7+
DIRK43_2,
8+
)
9+
10+
11+
class RungeKuttaDAE(RungeKutta):
12+
r"""
13+
Custom sweeper class to implement Runge-Kutta (RK) methods for general differential-algebraic equations (DAEs)
14+
of the form
15+
16+
.. math::
17+
0 = F(u, u', t).
18+
19+
RK methods for general DAEs have the form
20+
21+
.. math::
22+
0 = F(u_0 + \Delta t \sum_{j=1}^M a_{i,j} U_j, U_m),
23+
24+
.. math::
25+
u_M = u_0 + \Delta t \sum_{j=1}^M b_j U_j.
26+
27+
In pySDC, RK methods are implemented in the way that the coefficient matrix :math:`A` in the Butcher
28+
tableau is a lower triangular matrix so that the stages are updated node-by-node. This class therefore only supports
29+
RK methods with lower triangular matrices in the tableau.
30+
31+
Parameters
32+
----------
33+
params : dict
34+
Parameters passed to the sweeper.
35+
36+
Attributes
37+
----------
38+
du_init : dtype_f
39+
Stores the initial condition for each step.
40+
41+
Note
42+
----
43+
When using a RK sweeper to simulate a problem make sure the DAE problem class has a ``du_exact`` method since RK methods need an initial
44+
condition for :math:`u'(t_0)` as well.
45+
46+
In order to implement a new RK method for DAEs a new tableau can be added in ``pySDC.implementations.sweeper_classes.Runge_Kutta.py``.
47+
For example, a new method called ``newRungeKuttaMethod`` with nodes :math:`c=(c_1, c_2, c_3)`, weights :math:`b=(b_1, b_2, b_3)` and
48+
coefficient matrix
49+
50+
..math::
51+
\begin{eqnarray}
52+
A = \begin{pmatrix}
53+
a_{11} & 0 & 0 \\
54+
a_{21} & a_{22} & 0 \\
55+
a_{31} & a_{32} & & 0 \\
56+
\end{pmatrix}
57+
\end{eqnarray}
58+
59+
can be implemented as follows:
60+
61+
>>> class newRungeKuttaMethod(RungeKutta):
62+
>>> nodes = np.array([c1, c2, c3])
63+
>>> weights = np.array([b1, b2, b3])
64+
>>> matrix = np.zeros((3, 3))
65+
>>> matrix[0, 0] = a11
66+
>>> matrix[1, :2] = [a21, a22]
67+
>>> matrix[2, :] = [a31, a32, a33]
68+
>>> ButcherTableauClass = ButcherTableau
69+
70+
The new class ``newRungeKuttaMethodDAE`` can then be used by defining the DAE class inheriting from both, this base class and class containing
71+
the Butcher tableau:
72+
73+
>>> class newRungeKuttaMethodDAE(RungeKuttaDAE, newRungeKuttaMethod):
74+
>>> pass
75+
76+
More details can be found [here](https://github.com/Parallel-in-Time/pySDC/blob/master/pySDC/implementations/sweeper_classes/Runge_Kutta.py).
77+
"""
78+
79+
def __init__(self, params):
80+
super().__init__(params)
81+
self.du_init = None
82+
self.fully_initialized = False
83+
84+
def predict(self):
85+
"""
86+
Predictor to fill values with zeros at nodes before first sweep.
87+
"""
88+
89+
# get current level and problem
90+
lvl = self.level
91+
prob = lvl.prob
92+
93+
if not self.fully_initialized:
94+
self.du_init = prob.du_exact(lvl.time)
95+
self.fully_initialized = True
96+
97+
lvl.f[0] = prob.dtype_f(self.du_init)
98+
for m in range(1, self.coll.num_nodes + 1):
99+
lvl.u[m] = prob.dtype_u(init=prob.init, val=0.0)
100+
lvl.f[m] = prob.dtype_f(init=prob.init, val=0.0)
101+
102+
lvl.status.unlocked = True
103+
lvl.status.updated = True
104+
105+
def update_nodes(self):
106+
r"""
107+
Updates the values of solution ``u`` and their gradient stored in ``f``.
108+
"""
109+
110+
# get current level and problem description
111+
lvl = self.level
112+
prob = lvl.prob
113+
114+
# only if the level has been touched before
115+
assert lvl.status.unlocked
116+
assert lvl.status.sweep <= 1, "RK schemes are direct solvers. Please perform only 1 iteration!"
117+
118+
M = self.coll.num_nodes
119+
for m in range(M):
120+
u_approx = prob.dtype_u(lvl.u[0])
121+
for j in range(1, m + 1):
122+
u_approx += lvl.dt * self.QI[m + 1, j] * lvl.f[j][:]
123+
124+
finit = lvl.f[m].flatten()
125+
lvl.f[m + 1][:] = prob.solve_system(
126+
fully_implicit_DAE.F,
127+
u_approx,
128+
lvl.dt * self.QI[m + 1, m + 1],
129+
finit,
130+
lvl.time + lvl.dt * self.coll.nodes[m + 1],
131+
)
132+
133+
# Update numerical solution - update value only at last node
134+
lvl.u[-1][:] = lvl.u[0]
135+
for j in range(1, M + 1):
136+
lvl.u[-1][:] += lvl.dt * self.coll.Qmat[-1, j] * lvl.f[j][:]
137+
138+
self.du_init = prob.dtype_f(lvl.f[-1])
139+
140+
lvl.status.updated = True
141+
142+
return None
143+
144+
145+
class BackwardEulerDAE(RungeKuttaDAE, BackwardEuler):
146+
pass
147+
148+
149+
class TrapezoidalRuleDAE(RungeKuttaDAE, CrankNicholson):
150+
pass
151+
152+
153+
class EDIRK4DAE(RungeKuttaDAE, EDIRK4):
154+
pass
155+
156+
157+
class DIRK43_2DAE(RungeKuttaDAE, DIRK43_2):
158+
pass

0 commit comments

Comments
 (0)