Skip to content

Commit 17f4625

Browse files
author
Thomas Baumann
committed
Added a hook for logging the error at the end of the run.
1 parent 6c2d827 commit 17f4625

File tree

4 files changed

+79
-7
lines changed

4 files changed

+79
-7
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import numpy as np
2+
from pySDC.core.Hooks import hooks
3+
4+
5+
class LogGlobalErrorPostRun(hooks):
6+
"""
7+
Compute the global error once after the run is finished.
8+
"""
9+
10+
def __init__(self):
11+
"""
12+
Add an attribute for when the last solution was added.
13+
"""
14+
super().__init__()
15+
self.__t_last_solution = 0
16+
17+
def post_step(self, step, level_number):
18+
"""
19+
Store the time at which the solution is stored.
20+
This is required because between the `post_step` hook where the solution is stored and the `post_run` hook
21+
where the error is stored, the step size can change.
22+
23+
Args:
24+
step (pySDC.Step.step): The current step
25+
level_number (int): The index of the level
26+
27+
Returns:
28+
None
29+
"""
30+
super().post_step(step, level_number)
31+
self.__t_last_solution = step.levels[0].time + step.levels[0].dt
32+
33+
def post_run(self, step, level_number):
34+
"""
35+
Log the global error.
36+
37+
Args:
38+
step (pySDC.Step.step): The current step
39+
level_number (int): The index of the level
40+
41+
Returns:
42+
None
43+
"""
44+
super().post_run(step, level_number)
45+
46+
if level_number == 0:
47+
L = step.levels[level_number]
48+
49+
e_glob = np.linalg.norm(L.prob.u_exact(t=self.__t_last_solution) - L.uend, np.inf)
50+
51+
if step.status.last:
52+
self.logger.info(f'Finished with a global error of e={e_glob:.2e}')
53+
54+
self.add_to_stats(
55+
process=step.status.slot,
56+
time=L.time + L.dt,
57+
level=L.level_index,
58+
iter=step.status.iter,
59+
sweep=L.status.sweep,
60+
type='e_global',
61+
value=e_glob,
62+
)

pySDC/implementations/problem_classes/Lorenz.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class LorenzAttractor(ptype):
1313
1414
Since the problem is non-linear, we need to use a Newton solver.
1515
16-
Problem and initial conditions does not originate from, but was taken from doi.org/10.2140/camcos.2015.10.1
16+
Problem and initial conditions do not originate from, but were taken from doi.org/10.2140/camcos.2015.10.1
1717
"""
1818

1919
def __init__(self, problem_params):
@@ -93,7 +93,7 @@ def solve_system(self, rhs, dt, u0, t):
9393
# start Newton iterations
9494
u = self.dtype_u(u0)
9595
res = np.inf
96-
for n in range(0, self.params.newton_maxiter):
96+
for _n in range(0, self.params.newton_maxiter):
9797

9898
# assemble G such that G(u) = 0 at the solution to the step
9999
G = np.array(

pySDC/projects/Resilience/Lorentz.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def run_Lorenz(
6666
# initialize controller parameters
6767
controller_params = dict()
6868
controller_params['logger_level'] = 30
69-
controller_params['hook_class'] = hook_collection + [hook_class]
69+
controller_params['hook_class'] = hook_collection + hook_class if type(hook_class) == list else [hook_class]
7070
controller_params['mssdc_jac'] = False
7171

7272
if custom_controller_params is not None:
@@ -148,6 +148,7 @@ def plot_solution(stats):
148148
def check_solution(stats, controller, thresh=5e-4):
149149
"""
150150
Check if the global error solution wrt. a scipy reference solution is tolerable.
151+
This is also a check for the global error hook.
151152
152153
Args:
153154
stats (dict): The stats object of the run
@@ -159,8 +160,12 @@ def check_solution(stats, controller, thresh=5e-4):
159160
"""
160161
u = get_sorted(stats, type='u')
161162
u_exact = controller.MS[0].levels[0].prob.u_exact(t=u[-1][0])
162-
error = np.linalg.norm(u[-1][1] - u_exact)
163+
error = np.linalg.norm(u[-1][1] - u_exact, np.inf)
164+
error_hook = get_sorted(stats, type='e_global')[-1][1]
163165

166+
dt = get_sorted(stats, type='dt')
167+
168+
assert error == error_hook, f'Expected errors to match, got {error:.2e} and {error_hook:.2e}!'
164169
assert error < thresh, f"Error too large, got e={error:.2e}"
165170

166171

@@ -174,11 +179,16 @@ def main(plotting=True):
174179
Returns:
175180
None
176181
"""
182+
from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun
183+
177184
custom_description = {}
178185
custom_description['convergence_controllers'] = {Adaptivity: {'e_tol': 1e-5}}
179-
custom_controller_params = {'logger_level': 30}
186+
custom_controller_params = {'logger_level': 15}
180187
stats, controller, _ = run_Lorenz(
181-
custom_description=custom_description, custom_controller_params=custom_controller_params, Tend=10
188+
custom_description=custom_description,
189+
custom_controller_params=custom_controller_params,
190+
Tend=10,
191+
hook_class=[log_data, LogGlobalErrorPostRun],
182192
)
183193
check_solution(stats, controller, 5e-4)
184194
if plotting:

pySDC/tests/test_projects/test_resilience/test_Lorenz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
def test_main():
66
from pySDC.projects.Resilience.Lorentz import main
77

8-
main(False)
8+
main(plotting=False)

0 commit comments

Comments
 (0)