44from gusto .core .labels import explicit
55
66from pySDC .implementations .controller_classes .controller_nonMPI import controller_nonMPI
7+ from pySDC .implementations .controller_classes .controller_MPI import controller_MPI
78from pySDC .implementations .problem_classes .GenericGusto import GenericGusto , GenericGustoImex
89from pySDC .core .hooks import Hooks
910from pySDC .helpers .stats_helper import get_sorted
1011
12+ import logging
13+ import numpy as np
14+
1115
1216class LogTime (Hooks ):
1317 """
@@ -34,6 +38,10 @@ class pySDC_integrator(TimeDiscretisation):
3438 It will construct a pySDC controller which can be used by itself and will be used within the time step when called
3539 from Gusto. Access the controller via `pySDC_integrator.controller`. This class also has `pySDC_integrator.stats`,
3640 which gathers all of the pySDC stats recorded in the hooks during every time step when used within Gusto.
41+
42+ This class supports subcycling with multi-step SDC. You can use pseudo-parallelism by simply giving `n_steps` > 1 or
43+ do proper parallelism by giving a `controller_communicator` of kind `pySDC.FiredrakeEnsembleCommunicator` with the
44+ appropriate size. You also have to toggle between pseudo and proper parallelism with `useMPIController`.
3745 """
3846
3947 def __init__ (
@@ -44,8 +52,10 @@ def __init__(
4452 field_name = None ,
4553 solver_parameters = None ,
4654 options = None ,
47- t0 = 0 ,
4855 imex = False ,
56+ useMPIController = False ,
57+ n_steps = 1 ,
58+ controller_communicator = None ,
4959 ):
5060 """
5161 Initialization
@@ -63,6 +73,10 @@ def __init__(
6373 options to either be passed to the spatial discretisation, or
6474 to control the "wrapper" methods, such as Embedded DG or a
6575 recovery method. Defaults to None.
76+ imex (bool): Whether to use IMEX splitting
77+ useMPIController (bool): Whether to use the pseudo-parallel or proper parallel pySDC controller
78+ n_steps (int): Number of steps done in parallel when using pseudo-parallel pySDC controller
79+ controller_communicator (pySDC.FiredrakeEnsembleCommunicator, optional): Communicator for the proper parallel controller
6680 """
6781
6882 self ._residual = None
@@ -79,6 +93,23 @@ def __init__(
7993 self .timestepper = None
8094 self .dt_next = None
8195 self .imex = imex
96+ self .useMPIController = useMPIController
97+ self .controller_communicator = controller_communicator
98+
99+ if useMPIController :
100+ assert (
101+ type (self .controller_communicator ).__name__ == 'FiredrakeEnsembleCommunicator'
102+ ), f'Need to give a FiredrakeEnsembleCommunicator here, not { type (self .controller_communicator )} '
103+ if n_steps > 1 :
104+ logging .getLogger (type (self ).__name__ ).warning (
105+ f'Warning: You selected { n_steps = } , which will be ignored when using the MPI controller!'
106+ )
107+ assert (
108+ controller_communicator is not None
109+ ), 'You need to supply a communicator when using the MPI controller!'
110+ self .n_steps = controller_communicator .size
111+ else :
112+ self .n_steps = n_steps
82113
83114 def setup (self , equation , apply_bcs = True , * active_labels ):
84115 super ().setup (equation , apply_bcs , * active_labels )
@@ -96,7 +127,7 @@ def setup(self, equation, apply_bcs=True, *active_labels):
96127 'residual' : self ._residual ,
97128 ** self .description ['problem_params' ],
98129 }
99- self .description ['level_params' ]['dt' ] = float (self .domain .dt )
130+ self .description ['level_params' ]['dt' ] = float (self .domain .dt ) / self . n_steps
100131
101132 # add utility hook required for step size adaptivity
102133 hook_class = self .controller_params .get ('hook_class' , [])
@@ -106,7 +137,17 @@ def setup(self, equation, apply_bcs=True, *active_labels):
106137 self .controller_params ['hook_class' ] = hook_class
107138
108139 # prepare controller and variables
109- self .controller = controller_nonMPI (1 , description = self .description , controller_params = self .controller_params )
140+ if self .useMPIController :
141+ self .controller = controller_MPI (
142+ comm = self .controller_communicator ,
143+ description = self .description ,
144+ controller_params = self .controller_params ,
145+ )
146+ else :
147+ self .controller = controller_nonMPI (
148+ self .n_steps , description = self .description , controller_params = self .controller_params
149+ )
150+
110151 self .prob = self .level .prob
111152 self .sweeper = self .level .sweep
112153 self .x0_pySDC = self .prob .dtype_u (self .prob .init )
@@ -125,14 +166,26 @@ def residual(self):
125166 def residual (self , value ):
126167 """Make sure the pySDC problem residual and this residual are the same"""
127168 if hasattr (self , 'prob' ):
128- self .prob .residual = value
169+ if self .useMPIController :
170+ self .controller .S .levels [0 ].prob .residual = value
171+ else :
172+ for S in self .controller .MS :
173+ S .levels [0 ].prob .residual = value
129174 else :
130175 self ._residual = value
131176
177+ @property
178+ def step (self ):
179+ """Get the first step on the controller"""
180+ if self .useMPIController :
181+ return self .controller .S
182+ else :
183+ return self .controller .MS [0 ]
184+
132185 @property
133186 def level (self ):
134187 """Get the finest pySDC level"""
135- return self .controller . MS [ 0 ] .levels [0 ]
188+ return self .step .levels [0 ]
136189
137190 @wrapper_apply
138191 def apply (self , x_out , x_in ):
@@ -144,29 +197,31 @@ def apply(self, x_out, x_in):
144197 x_in (:class:`Function`): the input field.
145198 """
146199 self .x0_pySDC .functionspace .assign (x_in )
147- assert self .level .params .dt == float (self .dt ), 'Step sizes have diverged between pySDC and Gusto'
200+ assert np .isclose (
201+ self .level .params .dt * self .n_steps , float (self .dt )
202+ ), 'Step sizes have diverged between pySDC and Gusto'
148203
149204 if self .dt_next is not None :
150205 assert (
151206 self .timestepper is not None
152207 ), 'You need to set self.timestepper to the timestepper in order to facilitate adaptive step size selection here!'
153- self .timestepper .dt = fd .Constant (self .dt_next )
208+ self .timestepper .dt = fd .Constant (self .dt_next * self . n_steps )
154209 self .t = self .timestepper .t
155210
156211 uend , _stats = self .controller .run (u0 = self .x0_pySDC , t0 = float (self .t ), Tend = float (self .t + self .dt ))
157212
158213 # update time variables
159- if self .level .params .dt != float (self .dt ):
214+ if not np . isclose ( self .level .params .dt * self . n_steps , float (self .dt ) ):
160215 self .dt_next = self .level .params .dt
161216
162- self .t = get_sorted (_stats , type = '_time' , recomputed = False )[- 1 ][1 ]
217+ self .t = get_sorted (_stats , type = '_time' , recomputed = False , comm = self . controller_communicator )[- 1 ][1 ]
163218
164219 # update time of the Gusto stepper.
165220 # After this step, the Gusto stepper updates its time again to arrive at the correct time
166221 if self .timestepper is not None :
167222 self .timestepper .t = fd .Constant (self .t - self .dt )
168223
169- self .dt = self .level .params .dt
224+ self .dt = fd . Constant ( self .level .params .dt * self . n_steps )
170225
171226 # update stats and output
172227 self .stats = {** self .stats , ** _stats }
0 commit comments