Skip to content

Commit 525f6a1

Browse files
Sweepers require level during instantiation (#577)
* Put level as argument of sweeper * Fixed more sweepers in projects * Fixed second order SDC tests
1 parent 23ef9d3 commit 525f6a1

26 files changed

+110
-83
lines changed

pySDC/core/level.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,14 @@ def __init__(self, problem_class, problem_params, sweeper_class, sweeper_params,
6767
level_index (int): custom name for this level
6868
"""
6969

70-
# instantiate sweeper, problem and hooks
71-
self.__sweep = sweeper_class(sweeper_params)
72-
self.__prob = problem_class(**problem_params)
73-
7470
# set level parameters and status
7571
self.params = _Pars(level_params)
7672
self.status = _Status()
7773

74+
# instantiate sweeper, problem and hooks
75+
self.__sweep = sweeper_class(sweeper_params, self)
76+
self.__prob = problem_class(**problem_params)
77+
7878
# set name
7979
self.level_index = level_index
8080

@@ -90,9 +90,6 @@ def __init__(self, problem_class, problem_params, sweeper_class, sweeper_params,
9090

9191
self.tau = [None] * self.sweep.coll.num_nodes
9292

93-
# pass this level to the sweeper for easy access
94-
self.sweep.level = self
95-
9693
self.__tag = None
9794

9895
# freeze class, no further attributes allowed from this point

pySDC/core/sweeper.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,15 @@ class Sweeper(object):
4444
coll (pySDC.Collocation.CollBase): collocation object
4545
"""
4646

47-
def __init__(self, params):
47+
def __init__(self, params, level):
4848
"""
4949
Initialization routine for the base sweeper
5050
5151
Args:
5252
params (dict): parameter object
53-
53+
level (pySDC.Level.level): the level that uses this sweeper
5454
"""
5555

56-
# set up logger
5756
self.logger = logging.getLogger('sweeper')
5857

5958
essential_keys = ['num_nodes']
@@ -81,9 +80,7 @@ def __init__(self, params):
8180
)
8281
self.params.do_coll_update = True
8382

84-
# This will be set as soon as the sweeper is instantiated at the level
85-
self.__level = None
86-
83+
self.__level = level
8784
self.parallelizable = False
8885

8986
def setupGenerator(self, qd_type):

pySDC/helpers/setup_helper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def generate_description(problem_class, **kwargs):
2424

2525
problem_keys = problem_class.__init__.__code__.co_varnames
2626
level_keys = level_params({}).__dict__.keys()
27-
sweeper_keys = description['sweeper_class']({'num_nodes': 1, 'quad_type': 'RADAU-RIGHT'}).params.__dict__.keys()
27+
sweeper_keys = description['sweeper_class'](
28+
{'num_nodes': 1, 'quad_type': 'RADAU-RIGHT'}, None
29+
).params.__dict__.keys()
2830
step_keys = step_params({}).__dict__.keys()
2931

3032
# TODO: add convergence controllers

pySDC/implementations/convergence_controller_classes/adaptive_collocation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ def switch_sweeper(self, S):
142142
nodes_old = L.sweep.coll.nodes.copy()
143143

144144
# change sweeper
145-
L.sweep.__init__(update_params_sweeper)
146-
L.sweep.level = L
145+
L.sweep.__init__(update_params_sweeper, L)
147146

148147
# reset level to tell it the new structure of the solution
149148
L.params.__dict__.update(new_params_level)

pySDC/implementations/sweeper_classes/Multistep.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pySDC.core.sweeper import Sweeper, _Pars
2+
from pySDC.core.level import Level
23

34

45
class Cache(object):
@@ -55,7 +56,7 @@ class MultiStep(Sweeper):
5556
alpha = None
5657
beta = None
5758

58-
def __init__(self, params):
59+
def __init__(self, params, level):
5960
"""
6061
Initialization routine for the base sweeper.
6162
@@ -71,6 +72,7 @@ def __init__(self, params):
7172
7273
Args:
7374
params (dict): parameter object
75+
level (pySDC.Level.level): the level that uses this sweeper
7476
"""
7577
import logging
7678
from pySDC.core.collocation import CollBase
@@ -88,15 +90,36 @@ def __init__(self, params):
8890
# we need a dummy collocation object to instantiate the levels.
8991
self.coll = CollBase(num_nodes=1, quad_type='RADAU-RIGHT')
9092

91-
# This will be set as soon as the sweeper is instantiated at the level
92-
self.__level = None
93+
self.__level = level
9394

9495
self.parallelizable = False
9596

9697
# proprietary variables for the multistep methods
9798
self.steps = len(self.alpha)
9899
self.cache = Cache(self.steps)
99100

101+
@property
102+
def level(self):
103+
"""
104+
Returns the current level
105+
106+
Returns:
107+
pySDC.Level.level: Current level
108+
"""
109+
return self.__level
110+
111+
@level.setter
112+
def level(self, lvl):
113+
"""
114+
Sets a reference to the current level (done in the initialization of the level)
115+
116+
Args:
117+
lvl (pySDC.Level.level): Current level
118+
"""
119+
assert isinstance(lvl, Level), f"You tried to set the sweeper's level with an instance of {type(lvl)}!"
120+
121+
self.__level = lvl
122+
100123
def predict(self):
101124
"""
102125
Add the initial conditions to the cache if needed.

pySDC/implementations/sweeper_classes/Runge_Kutta.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,13 @@ class RungeKutta(Sweeper):
125125
The entries of the Butcher tableau are stored as class attributes.
126126
"""
127127

128-
def __init__(self, params):
128+
def __init__(self, params, level):
129129
"""
130130
Initialization routine for the custom sweeper
131131
132132
Args:
133133
params: parameters for the sweeper
134+
level (pySDC.Level.level): the level that uses this sweeper
134135
"""
135136
# set up logger
136137
self.logger = logging.getLogger('sweeper')
@@ -156,8 +157,9 @@ def __init__(self, params):
156157

157158
self.params = _Pars(params)
158159

159-
# This will be set as soon as the sweeper is instantiated at the level
160+
# set level using the setter in order to adapt residual tolerance if needed
160161
self.__level = None
162+
self.level = level
161163

162164
self.parallelizable = False
163165
self.QI = self.coll.Qmat
@@ -343,14 +345,15 @@ class RungeKuttaIMEX(RungeKutta):
343345
weights_explicit = None
344346
ButcherTableauClass_explicit = ButcherTableau
345347

346-
def __init__(self, params):
348+
def __init__(self, params, level):
347349
"""
348350
Initialization routine
349351
350352
Args:
351353
params: parameters for the sweeper
354+
level (pySDC.Level.level): the level that uses this sweeper
352355
"""
353-
super().__init__(params)
356+
super().__init__(params, level)
354357
type(self).weights_explicit = self.weights if self.weights_explicit is None else self.weights_explicit
355358
self.coll_explicit = self.get_Butcher_tableau_explicit()
356359
self.QE = self.coll_explicit.Qmat

pySDC/implementations/sweeper_classes/Runge_Kutta_Nystrom.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,15 @@ class RungeKuttaNystrom(RungeKutta):
104104
weights_bar = None
105105
matrix_bar = None
106106

107-
def __init__(self, params):
107+
def __init__(self, params, level):
108108
"""
109109
Initialization routine for the custom sweeper
110110
111111
Args:
112112
params: parameters for the sweeper
113+
level (pySDC.Level.level): the level that uses this sweeper
113114
"""
114-
super().__init__(params)
115+
super().__init__(params, level)
115116
self.coll_bar = self.get_Butcher_tableau_bar()
116117
self.Qx = self.coll_bar.Qmat
117118

pySDC/implementations/sweeper_classes/boris_2nd_order.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,21 @@ class boris_2nd_order(Sweeper):
1616
Sx: node-to-node Euler half-step for position update
1717
"""
1818

19-
def __init__(self, params):
19+
def __init__(self, params, level):
2020
"""
2121
Initialization routine for the custom sweeper
2222
2323
Args:
2424
params: parameters for the sweeper
25+
level (pySDC.Level.level): the level that uses this sweeper
2526
"""
2627

27-
# call parent's initialization routine
28-
2928
if "QI" not in params:
3029
params["QI"] = "IE"
3130
if "QE" not in params:
3231
params["QE"] = "EE"
3332

34-
super(boris_2nd_order, self).__init__(params)
33+
super(boris_2nd_order, self).__init__(params, level)
3534

3635
# S- and SQ-matrices (derived from Q) and Sx- and ST-matrices for the integrator
3736
[

pySDC/implementations/sweeper_classes/explicit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,19 @@ class explicit(Sweeper):
99
QE: explicit Euler integration matrix
1010
"""
1111

12-
def __init__(self, params):
12+
def __init__(self, params, level):
1313
"""
1414
Initialization routine for the custom sweeper
1515
1616
Args:
1717
params: parameters for the sweeper
18+
level (pySDC.Level.level): the level that uses this sweeper
1819
"""
1920

2021
if 'QE' not in params:
2122
params['QE'] = 'EE'
2223

23-
# call parent's initialization routine
24-
super(explicit, self).__init__(params)
24+
super(explicit, self).__init__(params, level)
2525

2626
# integration matrix
2727
self.QE = self.get_Qdelta_explicit(qd_type=self.params.QE)

pySDC/implementations/sweeper_classes/generic_implicit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,19 @@ class generic_implicit(Sweeper):
99
QI: lower triangular matrix
1010
"""
1111

12-
def __init__(self, params):
12+
def __init__(self, params, level):
1313
"""
1414
Initialization routine for the custom sweeper
1515
1616
Args:
1717
params: parameters for the sweeper
18+
level (pySDC.Level.level): the level that uses this sweeper
1819
"""
1920

2021
if 'QI' not in params:
2122
params['QI'] = 'IE'
2223

23-
# call parent's initialization routine
24-
super().__init__(params)
24+
super().__init__(params, level)
2525

2626
# get QI matrix
2727
self.QI = self.get_Qdelta_implicit(qd_type=self.params.QI)

0 commit comments

Comments
 (0)