Skip to content

Commit 92a9a1e

Browse files
authored
Merge pull request #246 from brownbaerchen/controller_restructure
Multiple Hooks simultaneously
2 parents 1de8d2f + 66cc5bc commit 92a9a1e

21 files changed

+699
-408
lines changed

pySDC/core/Controller.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import sys
44
import numpy as np
55

6-
from pySDC.core import Hooks as hookclass
76
from pySDC.core.BaseTransfer import base_transfer
87
from pySDC.helpers.pysdc_helper import FrozenClass
98
from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
9+
from pySDC.implementations.hooks.default_hook import DefaultHooks
1010

1111

1212
# short helper class to add params as attributes
@@ -41,10 +41,15 @@ def __init__(self, controller_params, description):
4141
"""
4242

4343
# check if we have a hook on this list. If not, use default class.
44-
controller_params['hook_class'] = controller_params.get('hook_class', hookclass.hooks)
45-
self.__hooks = controller_params['hook_class']()
44+
self.__hooks = []
45+
hook_classes = [DefaultHooks]
46+
user_hooks = controller_params.get('hook_class', [])
47+
hook_classes += user_hooks if type(user_hooks) == list else [user_hooks]
48+
[self.add_hook(hook) for hook in hook_classes]
49+
controller_params['hook_class'] = controller_params.get('hook_class', hook_classes)
4650

47-
self.hooks.pre_setup(step=None, level_number=None)
51+
for hook in self.hooks:
52+
hook.pre_setup(step=None, level_number=None)
4853

4954
self.params = _Pars(controller_params)
5055

@@ -101,6 +106,20 @@ def __setup_custom_logger(level=None, log_to_file=None, fname=None):
101106
else:
102107
pass
103108

109+
def add_hook(self, hook):
110+
"""
111+
Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
112+
The hook is only added if a hook of the same class is not already present.
113+
114+
Args:
115+
hook (pySDC.Hook): A hook class that is derived from the core hook class
116+
117+
Returns:
118+
None
119+
"""
120+
if hook not in [type(me) for me in self.hooks]:
121+
self.__hooks += [hook()]
122+
104123
def welcome_message(self):
105124
out = (
106125
"Welcome to the one and only, really very astonishing and 87.3% bug free"
@@ -308,3 +327,15 @@ def get_convergence_controllers_as_table(self, description):
308327
out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'
309328

310329
return out
330+
331+
def return_stats(self):
332+
"""
333+
Return the merged stats from all hooks
334+
335+
Returns:
336+
dict: Merged stats from all hooks
337+
"""
338+
stats = {}
339+
for hook in self.hooks:
340+
stats = {**stats, **hook.return_stats()}
341+
return stats

pySDC/core/Hooks.py

Lines changed: 6 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import time
32
from collections import namedtuple
43

54

@@ -8,23 +7,13 @@ class hooks(object):
87
"""
98
Hook class to contain the functions called during the controller runs (e.g. for calling user-routines)
109
10+
When deriving a custom hook from this class make sure to always call the parent method using e.g.
11+
`super().post_step(step, level_number)`. Otherwise bugs may arise when using `filer_recomputed` from the stats
12+
helper for post processing.
13+
1114
Attributes:
12-
__t0_setup (float): private variable to get starting time of setup
13-
__t0_run (float): private variable to get starting time of the run
14-
__t0_predict (float): private variable to get starting time of the predictor
15-
__t0_step (float): private variable to get starting time of the step
16-
__t0_iteration (float): private variable to get starting time of the iteration
17-
__t0_sweep (float): private variable to get starting time of the sweep
18-
__t0_comm (list): private variable to get starting time of the communication
19-
__t1_run (float): private variable to get end time of the run
20-
__t1_predict (float): private variable to get end time of the predictor
21-
__t1_step (float): private variable to get end time of the step
22-
__t1_iteration (float): private variable to get end time of the iteration
23-
__t1_sweep (float): private variable to get end time of the sweep
24-
__t1_setup (float): private variable to get end time of setup
25-
__t1_comm (list): private variable to hold timing of the communication (!)
26-
__num_restarts (int): number of restarts of the current step
2715
logger: logger instance for output
16+
__num_restarts (int): number of restarts of the current step
2817
__stats (dict): dictionary for gathering the statistics of a run
2918
__entry (namedtuple): statistics entry containing all information to identify the value
3019
"""
@@ -33,20 +22,6 @@ def __init__(self):
3322
"""
3423
Initialization routine
3524
"""
36-
self.__t0_setup = None
37-
self.__t0_run = None
38-
self.__t0_predict = None
39-
self.__t0_step = None
40-
self.__t0_iteration = None
41-
self.__t0_sweep = None
42-
self.__t0_comm = []
43-
self.__t1_run = None
44-
self.__t1_predict = None
45-
self.__t1_step = None
46-
self.__t1_iteration = None
47-
self.__t1_sweep = None
48-
self.__t1_setup = None
49-
self.__t1_comm = []
5025
self.__num_restarts = 0
5126

5227
self.logger = logging.getLogger('hooks')
@@ -130,7 +105,6 @@ def pre_setup(self, step, level_number):
130105
level_number (int): the current level number
131106
"""
132107
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
133-
self.__t0_setup = time.perf_counter()
134108

135109
def pre_run(self, step, level_number):
136110
"""
@@ -141,7 +115,6 @@ def pre_run(self, step, level_number):
141115
level_number (int): the current level number
142116
"""
143117
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
144-
self.__t0_run = time.perf_counter()
145118

146119
def pre_predict(self, step, level_number):
147120
"""
@@ -151,7 +124,7 @@ def pre_predict(self, step, level_number):
151124
step (pySDC.Step.step): the current step
152125
level_number (int): the current level number
153126
"""
154-
self.__t0_predict = time.perf_counter()
127+
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
155128

156129
def pre_step(self, step, level_number):
157130
"""
@@ -162,7 +135,6 @@ def pre_step(self, step, level_number):
162135
level_number (int): the current level number
163136
"""
164137
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
165-
self.__t0_step = time.perf_counter()
166138

167139
def pre_iteration(self, step, level_number):
168140
"""
@@ -173,7 +145,6 @@ def pre_iteration(self, step, level_number):
173145
level_number (int): the current level number
174146
"""
175147
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
176-
self.__t0_iteration = time.perf_counter()
177148

178149
def pre_sweep(self, step, level_number):
179150
"""
@@ -184,7 +155,6 @@ def pre_sweep(self, step, level_number):
184155
level_number (int): the current level number
185156
"""
186157
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
187-
self.__t0_sweep = time.perf_counter()
188158

189159
def pre_comm(self, step, level_number):
190160
"""
@@ -195,16 +165,6 @@ def pre_comm(self, step, level_number):
195165
level_number (int): the current level number
196166
"""
197167
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
198-
if len(self.__t0_comm) >= level_number + 1:
199-
self.__t0_comm[level_number] = time.perf_counter()
200-
else:
201-
while len(self.__t0_comm) < level_number:
202-
self.__t0_comm.append(None)
203-
self.__t0_comm.append(time.perf_counter())
204-
while len(self.__t1_comm) <= level_number:
205-
self.__t1_comm.append(0.0)
206-
assert len(self.__t0_comm) == level_number + 1
207-
assert len(self.__t1_comm) == level_number + 1
208168

209169
def post_comm(self, step, level_number, add_to_stats=False):
210170
"""
@@ -216,22 +176,6 @@ def post_comm(self, step, level_number, add_to_stats=False):
216176
add_to_stats (bool): set if result should go to stats object
217177
"""
218178
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
219-
assert len(self.__t1_comm) >= level_number + 1
220-
self.__t1_comm[level_number] += time.perf_counter() - self.__t0_comm[level_number]
221-
222-
if add_to_stats:
223-
L = step.levels[level_number]
224-
225-
self.add_to_stats(
226-
process=step.status.slot,
227-
time=L.time,
228-
level=L.level_index,
229-
iter=step.status.iter,
230-
sweep=L.status.sweep,
231-
type='timing_comm',
232-
value=self.__t1_comm[level_number],
233-
)
234-
self.__t1_comm[level_number] = 0.0
235179

236180
def post_sweep(self, step, level_number):
237181
"""
@@ -242,39 +186,6 @@ def post_sweep(self, step, level_number):
242186
level_number (int): the current level number
243187
"""
244188
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
245-
self.__t1_sweep = time.perf_counter()
246-
247-
L = step.levels[level_number]
248-
249-
self.logger.info(
250-
'Process %2i on time %8.6f at stage %15s: Level: %s -- Iteration: %2i -- Sweep: %2i -- ' 'residual: %12.8e',
251-
step.status.slot,
252-
L.time,
253-
step.status.stage,
254-
L.level_index,
255-
step.status.iter,
256-
L.status.sweep,
257-
L.status.residual,
258-
)
259-
260-
self.add_to_stats(
261-
process=step.status.slot,
262-
time=L.time,
263-
level=L.level_index,
264-
iter=step.status.iter,
265-
sweep=L.status.sweep,
266-
type='residual_post_sweep',
267-
value=L.status.residual,
268-
)
269-
self.add_to_stats(
270-
process=step.status.slot,
271-
time=L.time,
272-
level=L.level_index,
273-
iter=step.status.iter,
274-
sweep=L.status.sweep,
275-
type='timing_sweep',
276-
value=self.__t1_sweep - self.__t0_sweep,
277-
)
278189

279190
def post_iteration(self, step, level_number):
280191
"""
@@ -286,29 +197,6 @@ def post_iteration(self, step, level_number):
286197
"""
287198
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
288199

289-
self.__t1_iteration = time.perf_counter()
290-
291-
L = step.levels[level_number]
292-
293-
self.add_to_stats(
294-
process=step.status.slot,
295-
time=L.time,
296-
level=-1,
297-
iter=step.status.iter,
298-
sweep=L.status.sweep,
299-
type='residual_post_iteration',
300-
value=L.status.residual,
301-
)
302-
self.add_to_stats(
303-
process=step.status.slot,
304-
time=L.time,
305-
level=L.level_index,
306-
iter=step.status.iter,
307-
sweep=L.status.sweep,
308-
type='timing_iteration',
309-
value=self.__t1_iteration - self.__t0_iteration,
310-
)
311-
312200
def post_step(self, step, level_number):
313201
"""
314202
Default routine called after each step or block
@@ -319,44 +207,6 @@ def post_step(self, step, level_number):
319207
"""
320208
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
321209

322-
self.__t1_step = time.perf_counter()
323-
324-
L = step.levels[level_number]
325-
326-
self.add_to_stats(
327-
process=step.status.slot,
328-
time=L.time,
329-
level=L.level_index,
330-
iter=step.status.iter,
331-
sweep=L.status.sweep,
332-
type='timing_step',
333-
value=self.__t1_step - self.__t0_step,
334-
)
335-
self.add_to_stats(
336-
process=step.status.slot,
337-
time=L.time,
338-
level=-1,
339-
iter=step.status.iter,
340-
sweep=L.status.sweep,
341-
type='niter',
342-
value=step.status.iter,
343-
)
344-
self.add_to_stats(
345-
process=step.status.slot,
346-
time=L.time,
347-
level=L.level_index,
348-
iter=-1,
349-
sweep=L.status.sweep,
350-
type='residual_post_step',
351-
value=L.status.residual,
352-
)
353-
354-
# record the recomputed quantities at weird positions to make sure there is only one value for each step
355-
for t in [L.time, L.time + L.dt]:
356-
self.add_to_stats(
357-
process=-1, time=t, level=-1, iter=-1, sweep=-1, type='_recomputed', value=step.status.get('restart')
358-
)
359-
360210
def post_predict(self, step, level_number):
361211
"""
362212
Default routine called after each predictor
@@ -366,19 +216,6 @@ def post_predict(self, step, level_number):
366216
level_number (int): the current level number
367217
"""
368218
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
369-
self.__t1_predict = time.perf_counter()
370-
371-
L = step.levels[level_number]
372-
373-
self.add_to_stats(
374-
process=step.status.slot,
375-
time=L.time,
376-
level=L.level_index,
377-
iter=step.status.iter,
378-
sweep=L.status.sweep,
379-
type='timing_predictor',
380-
value=self.__t1_predict - self.__t0_predict,
381-
)
382219

383220
def post_run(self, step, level_number):
384221
"""
@@ -389,19 +226,6 @@ def post_run(self, step, level_number):
389226
level_number (int): the current level number
390227
"""
391228
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
392-
self.__t1_run = time.perf_counter()
393-
394-
L = step.levels[level_number]
395-
396-
self.add_to_stats(
397-
process=step.status.slot,
398-
time=L.time,
399-
level=L.level_index,
400-
iter=step.status.iter,
401-
sweep=L.status.sweep,
402-
type='timing_run',
403-
value=self.__t1_run - self.__t0_run,
404-
)
405229

406230
def post_setup(self, step, level_number):
407231
"""
@@ -412,14 +236,3 @@ def post_setup(self, step, level_number):
412236
level_number (int): the current level number
413237
"""
414238
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
415-
self.__t1_setup = time.perf_counter()
416-
417-
self.add_to_stats(
418-
process=-1,
419-
time=-1,
420-
level=-1,
421-
iter=-1,
422-
sweep=-1,
423-
type='timing_setup',
424-
value=self.__t1_setup - self.__t0_setup,
425-
)

0 commit comments

Comments
 (0)