Skip to content

Commit c74a9fe

Browse files
author
Thomas Baumann
committed
Allowing multiple hook classes now
1 parent 1b0d898 commit c74a9fe

File tree

6 files changed

+478
-243
lines changed

6 files changed

+478
-243
lines changed

pySDC/core/Controller.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import sys
44
import numpy as np
55

6-
from pySDC.core import Hooks as hookclass
76
from pySDC.core.BaseTransfer import base_transfer
87
from pySDC.helpers.pysdc_helper import FrozenClass
98
from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
9+
from pySDC.implementations.hooks.default_hook import default_hooks
1010

1111

1212
# short helper class to add params as attributes
@@ -41,10 +41,16 @@ def __init__(self, controller_params, description):
4141
"""
4242

4343
# check if we have a hook on this list. If not, use default class.
44-
controller_params['hook_class'] = controller_params.get('hook_class', hookclass.hooks)
45-
self.__hooks = controller_params['hook_class']()
44+
self.__hooks = []
45+
self.hook_classes = [default_hooks]
46+
user_hooks = controller_params.get('hook_class', [])
47+
self.hook_classes += user_hooks if type(user_hooks) == list else [user_hooks]
48+
for hook in self.hook_classes:
49+
self.__hooks += [hook()]
50+
controller_params['hook_class'] = controller_params.get('hook_class', self.hook_classes)
4651

47-
self.hooks.pre_setup(step=None, level_number=None)
52+
for hook in self.hooks:
53+
hook.pre_setup(step=None, level_number=None)
4854

4955
self.params = _Pars(controller_params)
5056

@@ -308,3 +314,15 @@ def get_convergence_controllers_as_table(self, description):
308314
out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'
309315

310316
return out
317+
318+
def return_stats(self):
319+
"""
320+
Return the merged stats from all hooks
321+
322+
Returns:
323+
dict: Merged stats from all hooks
324+
"""
325+
stats = {}
326+
for hook in self.hooks:
327+
stats = {**stats, **hook.return_stats()}
328+
return stats

pySDC/core/Hooks.py

Lines changed: 1 addition & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,6 @@ def __init__(self):
3333
"""
3434
Initialization routine
3535
"""
36-
self.__t0_setup = None
37-
self.__t0_run = None
38-
self.__t0_predict = None
39-
self.__t0_step = None
40-
self.__t0_iteration = None
41-
self.__t0_sweep = None
42-
self.__t0_comm = []
43-
self.__t1_run = None
44-
self.__t1_predict = None
45-
self.__t1_step = None
46-
self.__t1_iteration = None
47-
self.__t1_sweep = None
48-
self.__t1_setup = None
49-
self.__t1_comm = []
5036
self.__num_restarts = 0
5137

5238
self.logger = logging.getLogger('hooks')
@@ -130,7 +116,6 @@ def pre_setup(self, step, level_number):
130116
level_number (int): the current level number
131117
"""
132118
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
133-
self.__t0_setup = time.perf_counter()
134119

135120
def pre_run(self, step, level_number):
136121
"""
@@ -141,7 +126,6 @@ def pre_run(self, step, level_number):
141126
level_number (int): the current level number
142127
"""
143128
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
144-
self.__t0_run = time.perf_counter()
145129

146130
def pre_predict(self, step, level_number):
147131
"""
@@ -151,7 +135,7 @@ def pre_predict(self, step, level_number):
151135
step (pySDC.Step.step): the current step
152136
level_number (int): the current level number
153137
"""
154-
self.__t0_predict = time.perf_counter()
138+
pass
155139

156140
def pre_step(self, step, level_number):
157141
"""
@@ -162,7 +146,6 @@ def pre_step(self, step, level_number):
162146
level_number (int): the current level number
163147
"""
164148
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
165-
self.__t0_step = time.perf_counter()
166149

167150
def pre_iteration(self, step, level_number):
168151
"""
@@ -173,7 +156,6 @@ def pre_iteration(self, step, level_number):
173156
level_number (int): the current level number
174157
"""
175158
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
176-
self.__t0_iteration = time.perf_counter()
177159

178160
def pre_sweep(self, step, level_number):
179161
"""
@@ -184,7 +166,6 @@ def pre_sweep(self, step, level_number):
184166
level_number (int): the current level number
185167
"""
186168
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
187-
self.__t0_sweep = time.perf_counter()
188169

189170
def pre_comm(self, step, level_number):
190171
"""
@@ -195,16 +176,6 @@ def pre_comm(self, step, level_number):
195176
level_number (int): the current level number
196177
"""
197178
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
198-
if len(self.__t0_comm) >= level_number + 1:
199-
self.__t0_comm[level_number] = time.perf_counter()
200-
else:
201-
while len(self.__t0_comm) < level_number:
202-
self.__t0_comm.append(None)
203-
self.__t0_comm.append(time.perf_counter())
204-
while len(self.__t1_comm) <= level_number:
205-
self.__t1_comm.append(0.0)
206-
assert len(self.__t0_comm) == level_number + 1
207-
assert len(self.__t1_comm) == level_number + 1
208179

209180
def post_comm(self, step, level_number, add_to_stats=False):
210181
"""
@@ -216,22 +187,6 @@ def post_comm(self, step, level_number, add_to_stats=False):
216187
add_to_stats (bool): set if result should go to stats object
217188
"""
218189
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
219-
assert len(self.__t1_comm) >= level_number + 1
220-
self.__t1_comm[level_number] += time.perf_counter() - self.__t0_comm[level_number]
221-
222-
if add_to_stats:
223-
L = step.levels[level_number]
224-
225-
self.add_to_stats(
226-
process=step.status.slot,
227-
time=L.time,
228-
level=L.level_index,
229-
iter=step.status.iter,
230-
sweep=L.status.sweep,
231-
type='timing_comm',
232-
value=self.__t1_comm[level_number],
233-
)
234-
self.__t1_comm[level_number] = 0.0
235190

236191
def post_sweep(self, step, level_number):
237192
"""
@@ -242,39 +197,6 @@ def post_sweep(self, step, level_number):
242197
level_number (int): the current level number
243198
"""
244199
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
245-
self.__t1_sweep = time.perf_counter()
246-
247-
L = step.levels[level_number]
248-
249-
self.logger.info(
250-
'Process %2i on time %8.6f at stage %15s: Level: %s -- Iteration: %2i -- Sweep: %2i -- ' 'residual: %12.8e',
251-
step.status.slot,
252-
L.time,
253-
step.status.stage,
254-
L.level_index,
255-
step.status.iter,
256-
L.status.sweep,
257-
L.status.residual,
258-
)
259-
260-
self.add_to_stats(
261-
process=step.status.slot,
262-
time=L.time,
263-
level=L.level_index,
264-
iter=step.status.iter,
265-
sweep=L.status.sweep,
266-
type='residual_post_sweep',
267-
value=L.status.residual,
268-
)
269-
self.add_to_stats(
270-
process=step.status.slot,
271-
time=L.time,
272-
level=L.level_index,
273-
iter=step.status.iter,
274-
sweep=L.status.sweep,
275-
type='timing_sweep',
276-
value=self.__t1_sweep - self.__t0_sweep,
277-
)
278200

279201
def post_iteration(self, step, level_number):
280202
"""
@@ -286,29 +208,6 @@ def post_iteration(self, step, level_number):
286208
"""
287209
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
288210

289-
self.__t1_iteration = time.perf_counter()
290-
291-
L = step.levels[level_number]
292-
293-
self.add_to_stats(
294-
process=step.status.slot,
295-
time=L.time,
296-
level=-1,
297-
iter=step.status.iter,
298-
sweep=L.status.sweep,
299-
type='residual_post_iteration',
300-
value=L.status.residual,
301-
)
302-
self.add_to_stats(
303-
process=step.status.slot,
304-
time=L.time,
305-
level=L.level_index,
306-
iter=step.status.iter,
307-
sweep=L.status.sweep,
308-
type='timing_iteration',
309-
value=self.__t1_iteration - self.__t0_iteration,
310-
)
311-
312211
def post_step(self, step, level_number):
313212
"""
314213
Default routine called after each step or block
@@ -319,44 +218,6 @@ def post_step(self, step, level_number):
319218
"""
320219
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
321220

322-
self.__t1_step = time.perf_counter()
323-
324-
L = step.levels[level_number]
325-
326-
self.add_to_stats(
327-
process=step.status.slot,
328-
time=L.time,
329-
level=L.level_index,
330-
iter=step.status.iter,
331-
sweep=L.status.sweep,
332-
type='timing_step',
333-
value=self.__t1_step - self.__t0_step,
334-
)
335-
self.add_to_stats(
336-
process=step.status.slot,
337-
time=L.time,
338-
level=-1,
339-
iter=step.status.iter,
340-
sweep=L.status.sweep,
341-
type='niter',
342-
value=step.status.iter,
343-
)
344-
self.add_to_stats(
345-
process=step.status.slot,
346-
time=L.time,
347-
level=L.level_index,
348-
iter=-1,
349-
sweep=L.status.sweep,
350-
type='residual_post_step',
351-
value=L.status.residual,
352-
)
353-
354-
# record the recomputed quantities at weird positions to make sure there is only one value for each step
355-
for t in [L.time, L.time + L.dt]:
356-
self.add_to_stats(
357-
process=-1, time=t, level=-1, iter=-1, sweep=-1, type='_recomputed', value=step.status.get('restart')
358-
)
359-
360221
def post_predict(self, step, level_number):
361222
"""
362223
Default routine called after each predictor
@@ -366,19 +227,6 @@ def post_predict(self, step, level_number):
366227
level_number (int): the current level number
367228
"""
368229
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
369-
self.__t1_predict = time.perf_counter()
370-
371-
L = step.levels[level_number]
372-
373-
self.add_to_stats(
374-
process=step.status.slot,
375-
time=L.time,
376-
level=L.level_index,
377-
iter=step.status.iter,
378-
sweep=L.status.sweep,
379-
type='timing_predictor',
380-
value=self.__t1_predict - self.__t0_predict,
381-
)
382230

383231
def post_run(self, step, level_number):
384232
"""
@@ -389,19 +237,6 @@ def post_run(self, step, level_number):
389237
level_number (int): the current level number
390238
"""
391239
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
392-
self.__t1_run = time.perf_counter()
393-
394-
L = step.levels[level_number]
395-
396-
self.add_to_stats(
397-
process=step.status.slot,
398-
time=L.time,
399-
level=L.level_index,
400-
iter=step.status.iter,
401-
sweep=L.status.sweep,
402-
type='timing_run',
403-
value=self.__t1_run - self.__t0_run,
404-
)
405240

406241
def post_setup(self, step, level_number):
407242
"""
@@ -412,14 +247,3 @@ def post_setup(self, step, level_number):
412247
level_number (int): the current level number
413248
"""
414249
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
415-
self.__t1_setup = time.perf_counter()
416-
417-
self.add_to_stats(
418-
process=-1,
419-
time=-1,
420-
level=-1,
421-
iter=-1,
422-
sweep=-1,
423-
type='timing_setup',
424-
value=self.__t1_setup - self.__t0_setup,
425-
)

0 commit comments

Comments
 (0)