Skip to content

Commit e372a43

Browse files
Tweaks for FrozenClass (#437)
* Allow to add variables to frozen classes via class attribute * Bugfix * Added default value to uninitialized variables * Small cleanup
1 parent 5491e18 commit e372a43

File tree

9 files changed

+85
-200
lines changed

9 files changed

+85
-200
lines changed

pySDC/core/ConvergenceController.py

Lines changed: 32 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self, controller, params, description, **kwargs):
4242
params (dict): The params passed for this specific convergence controller
4343
description (dict): The description object used to instantiate the controller
4444
"""
45+
self.controller = controller
4546
self.params = Pars(self.setup(controller, params, description))
4647
params_ok, msg = self.check_parameters(controller, params, description)
4748
assert params_ok, f'{type(self).__name__} -- {msg}'
@@ -425,94 +426,43 @@ def Recv(self, comm, source, buffer, **kwargs):
425426

426427
return data
427428

428-
def reset_variable(self, controller, name, MPI=False, place=None, where=None, init=None):
429-
"""
430-
Utility function for resetting variables. This function will call the `add_variable` function with all the same
431-
arguments, but with `allow_overwrite = True`.
432-
433-
Args:
434-
controller (pySDC.Controller): The controller
435-
name (str): The name of the variable
436-
MPI (bool): Whether to use MPI controller
437-
place (object): The object you want to reset the variable of
438-
where (list): List of strings containing a path to where you want to reset the variable
439-
init: Initial value of the variable
440-
441-
Returns:
442-
None
443-
"""
444-
self.add_variable(controller, name, MPI, place, where, init, allow_overwrite=True)
429+
def add_status_variable_to_step(self, key, value=None):
430+
if type(self.controller).__name__ == 'controller_MPI':
431+
steps = [self.controller.S]
432+
else:
433+
steps = self.controller.MS
445434

446-
def add_variable(self, controller, name, MPI=False, place=None, where=None, init=None, allow_overwrite=False):
447-
"""
448-
Add a variable to a frozen class.
435+
steps[0].status.add_attr(key)
449436

450-
This function goes through the path to the destination of the variable recursively and adds it to all instances
451-
that are possible in the path. For example, giving `where = ["MS", "levels", "status"]` will result in adding a
452-
variable to the status object of all levels of all steps of the controller.
437+
if value is not None:
438+
self.set_step_status_variable(key, value)
453439

454-
Part of the functionality of the frozen class is to separate initialization and setting of variables. By
455-
enforcing this, you can make sure not to overwrite already existing variables. Since this function is called
456-
outside of the `__init__` function of the status objects, this can otherwise lead to bugs that are hard to find.
457-
For this reason, you need to specifically set `allow_overwrite = True` if you want to forgo the check if the
458-
variable already exists. This can be useful when resetting variables between steps, but make sure to set it to
459-
`allow_overwrite = False` the first time you add a variable.
440+
def set_step_status_variable(self, key, value):
441+
if type(self.controller).__name__ == 'controller_MPI':
442+
steps = [self.controller.S]
443+
else:
444+
steps = self.controller.MS
460445

461-
Args:
462-
controller (pySDC.Controller): The controller
463-
name (str): The name of the variable
464-
MPI (bool): Whether to use MPI controller
465-
place (object): The object you want to add the variable to
466-
where (list): List of strings containing a path to where you want to add the variable
467-
init: Initial value of the variable
468-
allow_overwrite (bool): Allow overwriting the variables if they already exist or raise an exception
446+
for S in steps:
447+
S.status.__dict__[key] = value
469448

470-
Returns:
471-
None
472-
"""
473-
where = ["S" if MPI else "MS", "levels", "status"] if where is None else where
474-
place = controller if place is None else place
449+
def add_status_variable_to_level(self, key, value=None):
450+
if type(self.controller).__name__ == 'controller_MPI':
451+
steps = [self.controller.S]
452+
else:
453+
steps = self.controller.MS
475454

476-
# check if we have arrived at the end of the path to the variable
477-
if len(where) == 0:
478-
variable_exitsts = name in place.__dict__.keys()
479-
# check if the variable already exists and raise an error in case we are about to introduce a bug
480-
if not allow_overwrite and variable_exitsts:
481-
raise ValueError(f"Key \"{name}\" already exists in {place}! Please rename the variable in {self}")
482-
# if we allow overwriting, but the variable does not exist already, we are violating the intended purpose
483-
# of this function, so we also raise an error if someone should be so mad as to attempt this
484-
elif allow_overwrite and not variable_exitsts:
485-
raise ValueError(f"Key \"{name}\" is supposed to be overwritten in {place}, but it does not exist!")
455+
steps[0].levels[0].status.add_attr(key)
486456

487-
# actually add or overwrite the variable
488-
place.__dict__[name] = init
457+
if value is not None:
458+
self.set_level_status_variable(key, value)
489459

490-
# follow the path to the final destination recursively
460+
def set_level_status_variable(self, key, value):
461+
if type(self.controller).__name__ == 'controller_MPI':
462+
steps = [self.controller.S]
491463
else:
492-
# get all possible new places to continue the path
493-
new_places = place.__dict__[where[0]]
494-
495-
# continue all possible paths
496-
if type(new_places) == list:
497-
# loop through all possibilities
498-
for new_place in new_places:
499-
self.add_variable(
500-
controller,
501-
name,
502-
MPI=MPI,
503-
place=new_place,
504-
where=where[1:],
505-
init=init,
506-
allow_overwrite=allow_overwrite,
507-
)
508-
else:
509-
# go to the only possible possibility
510-
self.add_variable(
511-
controller,
512-
name,
513-
MPI=MPI,
514-
place=new_places,
515-
where=where[1:],
516-
init=init,
517-
allow_overwrite=allow_overwrite,
518-
)
464+
steps = self.controller.MS
465+
466+
for S in steps:
467+
for L in S.levels:
468+
L.status.__dict__[key] = value

pySDC/core/Level.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ def __init__(self, params):
2121
class _Status(FrozenClass):
2222
"""
2323
This class carries the status of the level. All variables that the core SDC / PFASST functionality depend on are
24-
initialized here, while the convergence controllers are allowed to add more variables in a controlled fashion
25-
later on using the `add_variable` function.
24+
initialized here.
2625
"""
2726

2827
def __init__(self):

pySDC/core/Step.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ def __init__(self, params):
2020
class _Status(FrozenClass):
2121
"""
2222
This class carries the status of the step. All variables that the core SDC / PFASST functionality depend on are
23-
initialized here, while the convergence controllers are allowed to add more variables in a controlled fashion
24-
later on using the `add_variable` function.
23+
initialized here.
2524
"""
2625

2726
def __init__(self):

pySDC/helpers/pysdc_helper.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ class FrozenClass(object):
66
__isfrozen: Flag to freeze a class
77
"""
88

9+
attrs = []
10+
911
__isfrozen = False
1012

1113
def __setattr__(self, key, value):
@@ -18,10 +20,33 @@ def __setattr__(self, key, value):
1820
"""
1921

2022
# check if attribute exists and if class is frozen
21-
if self.__isfrozen and not hasattr(self, key):
22-
raise TypeError("%r is a frozen class" % self)
23+
if self.__isfrozen and not (key in self.attrs or hasattr(self, key)):
24+
raise TypeError(f'{type(self).__name__!r} is a frozen class, cannot add attribute {key!r}')
25+
2326
object.__setattr__(self, key, value)
2427

28+
def __getattr__(self, key):
29+
"""
30+
This is needed in case the variables have not been initialized after adding.
31+
"""
32+
if key in self.attrs:
33+
return None
34+
else:
35+
super().__getattr__(key)
36+
37+
@classmethod
38+
def add_attr(cls, key, raise_error_if_exists=False):
39+
"""
40+
Add a key to the allowed attributes of this class.
41+
42+
Args:
43+
key (str): The key to add
44+
raise_error_if_exists (bool): Raise an error if the attribute already exists in the class
45+
"""
46+
if key in cls.attrs and raise_error_if_exists:
47+
raise TypeError(f'Attribute {key!r} already exists in {cls.__name__}!')
48+
cls.attrs += [key]
49+
2550
def _freeze(self):
2651
"""
2752
Function to freeze the class
@@ -40,3 +65,9 @@ def get(self, key, default=None):
4065
__dict__.get(key, default)
4166
"""
4267
return self.__dict__.get(key, default)
68+
69+
def __dir__(self):
70+
"""
71+
My hope is that some editors can use this for dynamic autocompletion.
72+
"""
73+
return super().__dir__() + self.attrs

pySDC/implementations/convergence_controller_classes/basic_restarting.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,36 +76,26 @@ def setup(self, controller, params, description, **kwargs):
7676

7777
return {**defaults, **super().setup(controller, params, description, **kwargs)}
7878

79-
def setup_status_variables(self, controller, **kwargs):
79+
def setup_status_variables(self, *args, **kwargs):
8080
"""
8181
Add status variables for whether to restart now and how many times the step has been restarted in a row to the
8282
Steps
8383
84-
Args:
85-
controller (pySDC.Controller): The controller
86-
reset (bool): Whether the function is called for the first time or to reset
87-
8884
Returns:
8985
None
9086
"""
91-
where = ["S" if 'comm' in kwargs.keys() else "MS", "status"]
92-
self.add_variable(controller, name='restart', where=where, init=False)
93-
self.add_variable(controller, name='restarts_in_a_row', where=where, init=0)
87+
self.add_status_variable_to_step('restart', False)
88+
self.add_status_variable_to_step('restarts_in_a_row', 0)
9489

95-
def reset_status_variables(self, controller, reset=False, **kwargs):
90+
def reset_status_variables(self, *args, **kwargs):
9691
"""
9792
Add status variables for whether to restart now and how many times the step has been restarted in a row to the
9893
Steps
9994
100-
Args:
101-
controller (pySDC.Controller): The controller
102-
reset (bool): Whether the function is called for the first time or to reset
103-
10495
Returns:
10596
None
10697
"""
107-
where = ["S" if 'comm' in kwargs.keys() else "MS", "status"]
108-
self.reset_variable(controller, name='restart', where=where, init=False)
98+
self.set_step_status_variable('restart', False)
10999

110100
def dependencies(self, controller, description, **kwargs):
111101
"""

pySDC/implementations/convergence_controller_classes/estimate_contraction_factor.py

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,41 +39,17 @@ def dependencies(self, controller, description, **kwargs):
3939
description=description,
4040
)
4141

42-
def setup_status_variables(self, controller, **kwargs):
42+
def setup_status_variables(self, *args, **kwargs):
4343
"""
4444
Add the embedded error, contraction factor and iterations to convergence variable to the status of the levels.
4545
46-
Args:
47-
controller (pySDC.Controller): The controller
48-
49-
Returns:
50-
None
51-
"""
52-
if 'comm' in kwargs.keys():
53-
steps = [controller.S]
54-
else:
55-
if 'active_slots' in kwargs.keys():
56-
steps = [controller.MS[i] for i in kwargs['active_slots']]
57-
else:
58-
steps = controller.MS
59-
where = ["levels", "status"]
60-
for S in steps:
61-
self.add_variable(S, name='error_embedded_estimate_last_iter', where=where, init=None)
62-
self.add_variable(S, name='contraction_factor', where=where, init=None)
63-
if self.params.e_tol is not None:
64-
self.add_variable(S, name='iter_to_convergence', where=where, init=None)
65-
66-
def reset_status_variables(self, controller, **kwargs):
67-
"""
68-
Reinitialize new status variables for the levels.
69-
70-
Args:
71-
controller (pySDC.controller): The controller
72-
7346
Returns:
7447
None
7548
"""
76-
self.setup_status_variables(controller, **kwargs)
49+
self.add_status_variable_to_level('error_embedded_estimate_last_iter')
50+
self.add_status_variable_to_level('contraction_factor')
51+
if self.params.e_tol is not None:
52+
self.add_status_variable_to_level('iter_to_convergence')
7753

7854
def post_iteration_processing(self, controller, S, **kwargs):
7955
"""

pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -114,19 +114,7 @@ def setup_status_variables(self, controller, **kwargs):
114114
Args:
115115
controller (pySDC.Controller): The controller
116116
"""
117-
if 'comm' in kwargs.keys():
118-
steps = [controller.S]
119-
else:
120-
if 'active_slots' in kwargs.keys():
121-
steps = [controller.MS[i] for i in kwargs['active_slots']]
122-
else:
123-
steps = controller.MS
124-
where = ["levels", "status"]
125-
for S in steps:
126-
self.add_variable(S, name='error_embedded_estimate', where=where, init=None)
127-
128-
def reset_status_variables(self, controller, **kwargs):
129-
self.setup_status_variables(controller, **kwargs)
117+
self.add_status_variable_to_level('error_embedded_estimate')
130118

131119
def post_iteration_processing(self, controller, S, **kwargs):
132120
"""
@@ -350,7 +338,7 @@ def post_iteration_processing(self, controller, step, **kwargs):
350338
max([np.finfo(float).eps, abs(self.status.u[-1] - self.status.u[-2])]),
351339
)
352340

353-
def setup_status_variables(self, controller, **kwargs):
341+
def setup_status_variables(self, *args, **kwargs):
354342
"""
355343
Add the embedded error variable to the levels and add a status variable for previous steps.
356344
@@ -361,16 +349,4 @@ def setup_status_variables(self, controller, **kwargs):
361349
self.status.u = [] # the solutions of converged collocation problems
362350
self.status.iter = [] # the iteration in which the solution converged
363351

364-
if 'comm' in kwargs.keys():
365-
steps = [controller.S]
366-
else:
367-
if 'active_slots' in kwargs.keys():
368-
steps = [controller.MS[i] for i in kwargs['active_slots']]
369-
else:
370-
steps = controller.MS
371-
where = ["levels", "status"]
372-
for S in steps:
373-
self.add_variable(S, name='error_embedded_estimate_collocation', where=where, init=None)
374-
375-
def reset_status_variables(self, controller, **kwargs):
376-
self.setup_status_variables(controller, **kwargs)
352+
self.add_status_variable_to_level('error_embedded_estimate_collocation')

pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -84,29 +84,7 @@ def setup_status_variables(self, controller, **kwargs):
8484
self.coeff.u = [None] * self.params.n
8585
self.coeff.f = [0.0] * self.params.n
8686

87-
self.reset_status_variables(controller, **kwargs)
88-
return None
89-
90-
def reset_status_variables(self, controller, **kwargs):
91-
"""
92-
Add variable for extrapolated error
93-
94-
Args:
95-
controller (pySDC.Controller): The controller
96-
97-
Returns:
98-
None
99-
"""
100-
if 'comm' in kwargs.keys():
101-
steps = [controller.S]
102-
else:
103-
if 'active_slots' in kwargs.keys():
104-
steps = [controller.MS[i] for i in kwargs['active_slots']]
105-
else:
106-
steps = controller.MS
107-
where = ["levels", "status"]
108-
for S in steps:
109-
self.add_variable(S, name='error_extrapolation_estimate', where=where, init=None)
87+
self.add_status_variable_to_level('error_extrapolation_estimate')
11088

11189
def check_parameters(self, controller, params, description, **kwargs):
11290
"""

0 commit comments

Comments
 (0)