|
3 | 3 | import sys |
4 | 4 | import numpy as np |
5 | 5 |
|
6 | | -from pySDC.core import Hooks as hookclass |
7 | 6 | from pySDC.core.BaseTransfer import base_transfer |
8 | 7 | from pySDC.helpers.pysdc_helper import FrozenClass |
9 | 8 | from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence |
| 9 | +from pySDC.implementations.hooks.default_hook import DefaultHooks |
10 | 10 |
|
11 | 11 |
|
12 | 12 | # short helper class to add params as attributes |
@@ -41,10 +41,15 @@ def __init__(self, controller_params, description): |
41 | 41 | """ |
42 | 42 |
|
43 | 43 | # check if we have a hook on this list. If not, use default class. |
44 | | - controller_params['hook_class'] = controller_params.get('hook_class', hookclass.hooks) |
45 | | - self.__hooks = controller_params['hook_class']() |
| 44 | + self.__hooks = [] |
| 45 | + hook_classes = [DefaultHooks] |
| 46 | + user_hooks = controller_params.get('hook_class', []) |
| 47 | + hook_classes += user_hooks if type(user_hooks) == list else [user_hooks] |
| 48 | + [self.add_hook(hook) for hook in hook_classes] |
| 49 | + controller_params['hook_class'] = controller_params.get('hook_class', hook_classes) |
46 | 50 |
|
47 | | - self.hooks.pre_setup(step=None, level_number=None) |
| 51 | + for hook in self.hooks: |
| 52 | + hook.pre_setup(step=None, level_number=None) |
48 | 53 |
|
49 | 54 | self.params = _Pars(controller_params) |
50 | 55 |
|
@@ -101,6 +106,20 @@ def __setup_custom_logger(level=None, log_to_file=None, fname=None): |
101 | 106 | else: |
102 | 107 | pass |
103 | 108 |
|
| 109 | + def add_hook(self, hook): |
| 110 | + """ |
| 111 | + Add a hook to the controller which will be called in addition to all other hooks whenever something happens. |
| 112 | + The hook is only added if a hook of the same class is not already present. |
| 113 | +
|
| 114 | + Args: |
| 115 | + hook (pySDC.Hook): A hook class that is derived from the core hook class |
| 116 | +
|
| 117 | + Returns: |
| 118 | + None |
| 119 | + """ |
| 120 | + if hook not in [type(me) for me in self.hooks]: |
| 121 | + self.__hooks += [hook()] |
| 122 | + |
104 | 123 | def welcome_message(self): |
105 | 124 | out = ( |
106 | 125 | "Welcome to the one and only, really very astonishing and 87.3% bug free" |
@@ -308,3 +327,15 @@ def get_convergence_controllers_as_table(self, description): |
308 | 327 | out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}' |
309 | 328 |
|
310 | 329 | return out |
| 330 | + |
| 331 | + def return_stats(self): |
| 332 | + """ |
| 333 | + Return the merged stats from all hooks |
| 334 | +
|
| 335 | + Returns: |
| 336 | + dict: Merged stats from all hooks |
| 337 | + """ |
| 338 | + stats = {} |
| 339 | + for hook in self.hooks: |
| 340 | + stats = {**stats, **hook.return_stats()} |
| 341 | + return stats |
0 commit comments