Skip to content

Commit 6191177

Browse files
author
Thomas Baumann
committed
Error estimating convergence controllers will now add hooks to record
their estimates
1 parent c74a9fe commit 6191177

File tree

12 files changed

+50
-144
lines changed

12 files changed

+50
-144
lines changed

pySDC/core/Controller.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,11 @@ def __init__(self, controller_params, description):
4242

4343
# check if we have a hook on this list. If not, use default class.
4444
self.__hooks = []
45-
self.hook_classes = [default_hooks]
45+
hook_classes = [default_hooks]
4646
user_hooks = controller_params.get('hook_class', [])
47-
self.hook_classes += user_hooks if type(user_hooks) == list else [user_hooks]
48-
for hook in self.hook_classes:
49-
self.__hooks += [hook()]
50-
controller_params['hook_class'] = controller_params.get('hook_class', self.hook_classes)
47+
hook_classes += user_hooks if type(user_hooks) == list else [user_hooks]
48+
[self.add_hook(hook) for hook in hook_classes]
49+
controller_params['hook_class'] = controller_params.get('hook_class', hook_classes)
5150

5251
for hook in self.hooks:
5352
hook.pre_setup(step=None, level_number=None)
@@ -107,6 +106,20 @@ def __setup_custom_logger(level=None, log_to_file=None, fname=None):
107106
else:
108107
pass
109108

109+
def add_hook(self, hook):
110+
"""
111+
Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
112+
The hook is only added if a hook of the same class is not already present.
113+
114+
Args:
115+
hook (pySDC.Hook): A hook class that is derived from the core hook class
116+
117+
Returns:
118+
None
119+
"""
120+
if hook not in [type(me) for me in self.hooks]:
121+
self.__hooks += [hook()]
122+
110123
def welcome_message(self):
111124
out = (
112125
"Welcome to the one and only, really very astonishing and 87.3% bug free"

pySDC/core/Hooks.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import time
32
from collections import namedtuple
43

54

@@ -9,20 +8,6 @@ class hooks(object):
98
Hook class to contain the functions called during the controller runs (e.g. for calling user-routines)
109
1110
Attributes:
12-
__t0_setup (float): private variable to get starting time of setup
13-
__t0_run (float): private variable to get starting time of the run
14-
__t0_predict (float): private variable to get starting time of the predictor
15-
__t0_step (float): private variable to get starting time of the step
16-
__t0_iteration (float): private variable to get starting time of the iteration
17-
__t0_sweep (float): private variable to get starting time of the sweep
18-
__t0_comm (list): private variable to get starting time of the communication
19-
__t1_run (float): private variable to get end time of the run
20-
__t1_predict (float): private variable to get end time of the predictor
21-
__t1_step (float): private variable to get end time of the step
22-
__t1_iteration (float): private variable to get end time of the iteration
23-
__t1_sweep (float): private variable to get end time of the sweep
24-
__t1_setup (float): private variable to get end time of setup
25-
__t1_comm (list): private variable to hold timing of the communication (!)
2611
__num_restarts (int): number of restarts of the current step
2712
logger: logger instance for output
2813
__stats (dict): dictionary for gathering the statistics of a run

pySDC/implementations/controller_classes/controller_MPI.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ def it_up(self, comm, num_procs):
807807
self.S.levels[l - 1].sweep.update_nodes()
808808
self.S.levels[l - 1].sweep.compute_residual()
809809
for hook in self.hooks:
810-
hooks.post_sweep(step=self.S, level_number=l - 1)
810+
hook.post_sweep(step=self.S, level_number=l - 1)
811811

812812
# update stage
813813
self.S.status.stage = 'IT_FINE'

pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pySDC.core.ConvergenceController import ConvergenceController, Pars
44
from pySDC.implementations.convergence_controller_classes.store_uold import StoreUOld
5+
from pySDC.implementations.hooks.log_embedded_error_estimate import log_embedded_error_estimate
56

67
from pySDC.implementations.sweeper_classes.Runge_Kutta import RungeKutta
78

@@ -16,7 +17,7 @@ class EstimateEmbeddedError(ConvergenceController):
1617

1718
def __init__(self, controller, params, description, **kwargs):
1819
"""
19-
Initalization routine. Add the buffers for communication.
20+
Initialisation routine. Add the buffers for communication.
2021
2122
Args:
2223
controller (pySDC.Controller): The controller
@@ -25,6 +26,7 @@ def __init__(self, controller, params, description, **kwargs):
2526
"""
2627
super(EstimateEmbeddedError, self).__init__(controller, params, description, **kwargs)
2728
self.buffers = Pars({'e_em_last': 0.0})
29+
controller.add_hook(log_embedded_error_estimate)
2830

2931
@classmethod
3032
def get_implementation(cls, flavor):
@@ -123,7 +125,7 @@ def reset_status_variables(self, controller, **kwargs):
123125
class EstimateEmbeddedErrorNonMPI(EstimateEmbeddedError):
124126
def reset_buffers_nonMPI(self, controller, **kwargs):
125127
"""
126-
Reset buffers for immitated communication.
128+
Reset buffers for imitated communication.
127129
128130
Args:
129131
controller (pySDC.controller): The controller

pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pySDC.core.ConvergenceController import ConvergenceController, Status
55
from pySDC.core.Errors import DataError
66
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
7+
from pySDC.implementations.hooks.log_extrapolated_error_estimate import log_extrapolated_error_estimate
78

89

910
class EstimateExtrapolationErrorBase(ConvergenceController):
@@ -27,6 +28,7 @@ def __init__(self, controller, params, description, **kwargs):
2728
self.prev = Status(["t", "u", "f", "dt"]) # store solutions etc. of previous steps here
2829
self.coeff = Status(["u", "f", "prefactor"]) # store coefficients for extrapolation here
2930
super(EstimateExtrapolationErrorBase, self).__init__(controller, params, description)
31+
controller.add_hook(log_extrapolated_error_estimate)
3032

3133
def setup(self, controller, params, description, **kwargs):
3234
"""

pySDC/projects/PinTSimE/switch_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def get_new_step_size(self, controller, S):
9595
dt_search = self.t_switch - L.time
9696
L.prob.params.set_switch[self.count_switches] = self.switch_detected
9797
L.prob.params.t_switch[self.count_switches] = self.t_switch
98-
controller.hooks.add_to_stats(
98+
controller.hooks[0].add_to_stats(
9999
process=S.status.slot,
100100
time=L.time,
101101
level=L.level_index,

pySDC/projects/Resilience/accuracy_check.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,6 @@ def post_step(self, step, level_number):
3434

3535
L.sweep.compute_end_point()
3636

37-
self.add_to_stats(
38-
process=step.status.slot,
39-
time=L.time + L.dt,
40-
level=L.level_index,
41-
iter=0,
42-
sweep=L.status.sweep,
43-
type='e_embedded',
44-
value=L.status.error_embedded_estimate,
45-
)
46-
self.add_to_stats(
47-
process=step.status.slot,
48-
time=L.time + L.dt,
49-
level=L.level_index,
50-
iter=0,
51-
sweep=L.status.sweep,
52-
type='e_extrapolated',
53-
value=L.status.get('error_extrapolation_estimate'),
54-
)
5537
self.add_to_stats(
5638
process=step.status.slot,
5739
time=L.time,
@@ -105,8 +87,8 @@ def get_results_from_stats(stats, var, val, hook_class=log_errors):
10587
}
10688

10789
if hook_class == log_errors:
108-
e_extrapolated = np.array(get_sorted(stats, type='e_extrapolated'))[:, 1]
109-
e_embedded = np.array(get_sorted(stats, type='e_embedded'))[:, 1]
90+
e_extrapolated = np.array(get_sorted(stats, type='error_extrapolation_estimate'))[:, 1]
91+
e_embedded = np.array(get_sorted(stats, type='error_embedded_estimate'))[:, 1]
11092
e_loc = np.array(get_sorted(stats, type='e_loc'))[:, 1]
11193

11294
if len(e_extrapolated[e_extrapolated != [None]]) > 0:

pySDC/projects/Resilience/advection.py

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pySDC.core.Hooks import hooks
77
from pySDC.helpers.stats_helper import get_sorted
88
import numpy as np
9-
from pySDC.projects.Resilience.hook import log_error_estimates
9+
from pySDC.projects.Resilience.hook import log_data
1010

1111

1212
def plot_embedded(stats, ax):
@@ -21,17 +21,8 @@ def plot_embedded(stats, ax):
2121
ax.legend(frameon=False)
2222

2323

24-
class log_data(hooks):
25-
def pre_run(self, step, level_number):
26-
"""
27-
Record los conditiones initiales
28-
"""
29-
super(log_data, self).pre_run(step, level_number)
30-
L = step.levels[level_number]
31-
self.add_to_stats(process=0, time=0, level=0, iter=0, sweep=0, type='u0', value=L.u[0])
32-
24+
class log_every_iteration(hooks):
3325
def post_iteration(self, step, level_number):
34-
super(log_data, self).post_iteration(step, level_number)
3526
if step.status.iter == step.params.maxiter - 1:
3627
L = step.levels[level_number]
3728
L.sweep.compute_end_point()
@@ -45,58 +36,12 @@ def post_iteration(self, step, level_number):
4536
value=L.uold[-1],
4637
)
4738

48-
def post_step(self, step, level_number):
49-
50-
super(log_data, self).post_step(step, level_number)
51-
52-
# some abbreviations
53-
L = step.levels[level_number]
54-
55-
L.sweep.compute_end_point()
56-
57-
self.add_to_stats(
58-
process=step.status.slot,
59-
time=L.time + L.dt,
60-
level=L.level_index,
61-
iter=0,
62-
sweep=L.status.sweep,
63-
type='u',
64-
value=L.uend,
65-
)
66-
self.add_to_stats(
67-
process=step.status.slot,
68-
time=L.time,
69-
level=L.level_index,
70-
iter=0,
71-
sweep=L.status.sweep,
72-
type='dt',
73-
value=L.dt,
74-
)
75-
self.add_to_stats(
76-
process=step.status.slot,
77-
time=L.time + L.dt,
78-
level=L.level_index,
79-
iter=0,
80-
sweep=L.status.sweep,
81-
type='e_embedded',
82-
value=L.status.get('error_embedded_estimate'),
83-
)
84-
self.add_to_stats(
85-
process=step.status.slot,
86-
time=L.time + L.dt,
87-
level=L.level_index,
88-
iter=0,
89-
sweep=L.status.sweep,
90-
type='e_extrapolated',
91-
value=L.status.get('error_extrapolation_estimate'),
92-
)
93-
9439

9540
def run_advection(
9641
custom_description=None,
9742
num_procs=1,
9843
Tend=2e-1,
99-
hook_class=log_error_estimates,
44+
hook_class=log_data,
10045
fault_stuff=None,
10146
custom_controller_params=None,
10247
custom_problem_params=None,

pySDC/projects/Resilience/heat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
66
from pySDC.core.Hooks import hooks
77
from pySDC.helpers.stats_helper import get_sorted
8-
from pySDC.projects.Resilience.hook import log_error_estimates
8+
from pySDC.projects.Resilience.hook import log_data
99
import numpy as np
1010

1111

1212
def run_heat(
1313
custom_description=None,
1414
num_procs=1,
1515
Tend=2e-1,
16-
hook_class=log_error_estimates,
16+
hook_class=log_data,
1717
fault_stuff=None,
1818
custom_controller_params=None,
1919
custom_problem_params=None,

pySDC/projects/Resilience/hook.py

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
from pySDC.core.Hooks import hooks
2+
from pySDC.implementations.hooks.log_solution import log_solution
3+
from pySDC.implementations.hooks.log_embedded_error_estimate import log_embedded_error_estimate
4+
from pySDC.implementations.hooks.log_extrapolated_error_estimate import log_extrapolated_error_estimate
25

36

4-
class log_error_estimates(hooks):
7+
hook_collection = [log_solution, log_embedded_error_estimate, log_extrapolated_error_estimate]
8+
9+
10+
class log_data(hooks):
511
"""
612
Record data required for analysis of problems in the resilience project
713
"""
@@ -10,30 +16,18 @@ def pre_run(self, step, level_number):
1016
"""
1117
Record los conditiones initiales
1218
"""
13-
super(log_error_estimates, self).pre_run(step, level_number)
1419
L = step.levels[level_number]
1520
self.add_to_stats(process=0, time=0, level=0, iter=0, sweep=0, type='u0', value=L.u[0])
1621

1722
def post_step(self, step, level_number):
1823
"""
1924
Record final solutions as well as step size and error estimates
2025
"""
21-
super(log_error_estimates, self).post_step(step, level_number)
22-
2326
# some abbreviations
2427
L = step.levels[level_number]
2528

2629
L.sweep.compute_end_point()
2730

28-
self.add_to_stats(
29-
process=step.status.slot,
30-
time=L.time + L.dt,
31-
level=L.level_index,
32-
iter=0,
33-
sweep=L.status.sweep,
34-
type='u',
35-
value=L.uend,
36-
)
3731
self.add_to_stats(
3832
process=step.status.slot,
3933
time=L.time,
@@ -43,24 +37,6 @@ def post_step(self, step, level_number):
4337
type='dt',
4438
value=L.dt,
4539
)
46-
self.add_to_stats(
47-
process=step.status.slot,
48-
time=L.time + L.dt,
49-
level=L.level_index,
50-
iter=0,
51-
sweep=L.status.sweep,
52-
type='e_embedded',
53-
value=L.status.__dict__.get('error_embedded_estimate', None),
54-
)
55-
self.add_to_stats(
56-
process=step.status.slot,
57-
time=L.time + L.dt,
58-
level=L.level_index,
59-
iter=0,
60-
sweep=L.status.sweep,
61-
type='e_extrapolated',
62-
value=L.status.__dict__.get('error_extrapolation_estimate', None),
63-
)
6440
self.add_to_stats(
6541
process=step.status.slot,
6642
time=L.time,

0 commit comments

Comments
 (0)