44from gusto .core .labels import explicit
55
66from pySDC .implementations .controller_classes .controller_nonMPI import controller_nonMPI
7+ from pySDC .implementations .controller_classes .controller_MPI import controller_MPI
78from pySDC .implementations .problem_classes .GenericGusto import GenericGusto , GenericGustoImex
89from pySDC .core .hooks import Hooks
910from pySDC .helpers .stats_helper import get_sorted
1011
12+ import logging
13+ import numpy as np
14+
1115
1216class LogTime (Hooks ):
1317 """
@@ -34,25 +38,29 @@ class pySDC_integrator(TimeDiscretisation):
3438 It will construct a pySDC controller which can be used by itself and will be used within the time step when called
3539 from Gusto. Access the controller via `pySDC_integrator.controller`. This class also has `pySDC_integrator.stats`,
3640 which gathers all of the pySDC stats recorded in the hooks during every time step when used within Gusto.
41+
42+ This class supports subcycling with multi-step SDC. You can use pseudo-parallelism by simply giving `n_steps` > 1 or
43+ do proper parallelism by giving a `controller_communicator` of kind `pySDC.FiredrakeEnsembleCommunicator` with the
44+ appropriate size. You also have to toggle between pseudo and proper parallelism with `useMPIController`.
3745 """
3846
3947 def __init__ (
4048 self ,
41- equation ,
4249 description ,
4350 controller_params ,
4451 domain ,
4552 field_name = None ,
4653 solver_parameters = None ,
4754 options = None ,
48- t0 = 0 ,
4955 imex = False ,
56+ useMPIController = False ,
57+ n_steps = 1 ,
58+ controller_communicator = None ,
5059 ):
5160 """
5261 Initialization
5362
5463 Args:
55- equation (:class:`PrognosticEquation`): the prognostic equation.
5664 description (dict): pySDC description
5765 controller_params (dict): pySDC controller params
5866 domain (:class:`Domain`): the model's domain object, containing the
@@ -65,6 +73,10 @@ def __init__(
6573 options to either be passed to the spatial discretisation, or
6674 to control the "wrapper" methods, such as Embedded DG or a
6775 recovery method. Defaults to None.
76+ imex (bool): Whether to use IMEX splitting
77+ useMPIController (bool): Whether to use the pseudo-parallel or proper parallel pySDC controller
78+ n_steps (int): Number of steps done in parallel when using pseudo-parallel pySDC controller
79+ controller_communicator (pySDC.FiredrakeEnsembleCommunicator, optional): Communicator for the proper parallel controller
6880 """
6981
7082 self ._residual = None
@@ -81,6 +93,23 @@ def __init__(
8193 self .timestepper = None
8294 self .dt_next = None
8395 self .imex = imex
96+ self .useMPIController = useMPIController
97+ self .controller_communicator = controller_communicator
98+
99+ if useMPIController :
100+ assert (
101+ type (self .controller_communicator ).__name__ == 'FiredrakeEnsembleCommunicator'
102+ ), f'Need to give a FiredrakeEnsembleCommunicator here, not { type (self .controller_communicator )} '
103+ if n_steps > 1 :
104+ logging .getLogger (type (self ).__name__ ).warning (
105+ f'Warning: You selected { n_steps = } , which will be ignored when using the MPI controller!'
106+ )
107+ assert (
108+ controller_communicator is not None
109+ ), 'You need to supply a communicator when using the MPI controller!'
110+ self .n_steps = controller_communicator .size
111+ else :
112+ self .n_steps = n_steps
84113
85114 def setup (self , equation , apply_bcs = True , * active_labels ):
86115 super ().setup (equation , apply_bcs , * active_labels )
@@ -96,8 +125,9 @@ def setup(self, equation, apply_bcs=True, *active_labels):
96125 'equation' : equation ,
97126 'solver_parameters' : self .solver_parameters ,
98127 'residual' : self ._residual ,
128+ ** self .description ['problem_params' ],
99129 }
100- self .description ['level_params' ]['dt' ] = float (self .domain .dt )
130+ self .description ['level_params' ]['dt' ] = float (self .domain .dt ) / self . n_steps
101131
102132 # add utility hook required for step size adaptivity
103133 hook_class = self .controller_params .get ('hook_class' , [])
@@ -107,7 +137,17 @@ def setup(self, equation, apply_bcs=True, *active_labels):
107137 self .controller_params ['hook_class' ] = hook_class
108138
109139 # prepare controller and variables
110- self .controller = controller_nonMPI (1 , description = self .description , controller_params = self .controller_params )
140+ if self .useMPIController :
141+ self .controller = controller_MPI (
142+ comm = self .controller_communicator ,
143+ description = self .description ,
144+ controller_params = self .controller_params ,
145+ )
146+ else :
147+ self .controller = controller_nonMPI (
148+ self .n_steps , description = self .description , controller_params = self .controller_params
149+ )
150+
111151 self .prob = self .level .prob
112152 self .sweeper = self .level .sweep
113153 self .x0_pySDC = self .prob .dtype_u (self .prob .init )
@@ -126,14 +166,26 @@ def residual(self):
126166 def residual (self , value ):
127167 """Make sure the pySDC problem residual and this residual are the same"""
128168 if hasattr (self , 'prob' ):
129- self .prob .residual = value
169+ if self .useMPIController :
170+ self .controller .S .levels [0 ].prob .residual = value
171+ else :
172+ for S in self .controller .MS :
173+ S .levels [0 ].prob .residual = value
130174 else :
131175 self ._residual = value
132176
177+ @property
178+ def step (self ):
179+ """Get the first step on the controller"""
180+ if self .useMPIController :
181+ return self .controller .S
182+ else :
183+ return self .controller .MS [0 ]
184+
133185 @property
134186 def level (self ):
135187 """Get the finest pySDC level"""
136- return self .controller . MS [ 0 ] .levels [0 ]
188+ return self .step .levels [0 ]
137189
138190 @wrapper_apply
139191 def apply (self , x_out , x_in ):
@@ -145,29 +197,31 @@ def apply(self, x_out, x_in):
145197 x_in (:class:`Function`): the input field.
146198 """
147199 self .x0_pySDC .functionspace .assign (x_in )
148- assert self .level .params .dt == float (self .dt ), 'Step sizes have diverged between pySDC and Gusto'
200+ assert np .isclose (
201+ self .level .params .dt * self .n_steps , float (self .dt )
202+ ), 'Step sizes have diverged between pySDC and Gusto'
149203
150204 if self .dt_next is not None :
151205 assert (
152206 self .timestepper is not None
153207 ), 'You need to set self.timestepper to the timestepper in order to facilitate adaptive step size selection here!'
154- self .timestepper .dt = fd .Constant (self .dt_next )
208+ self .timestepper .dt = fd .Constant (self .dt_next * self . n_steps )
155209 self .t = self .timestepper .t
156210
157211 uend , _stats = self .controller .run (u0 = self .x0_pySDC , t0 = float (self .t ), Tend = float (self .t + self .dt ))
158212
159213 # update time variables
160- if self .level .params .dt != float (self .dt ):
214+ if not np . isclose ( self .level .params .dt * self . n_steps , float (self .dt ) ):
161215 self .dt_next = self .level .params .dt
162216
163- self .t = get_sorted (_stats , type = '_time' , recomputed = False )[- 1 ][1 ]
217+ self .t = get_sorted (_stats , type = '_time' , recomputed = False , comm = self . controller_communicator )[- 1 ][1 ]
164218
165219 # update time of the Gusto stepper.
166220 # After this step, the Gusto stepper updates its time again to arrive at the correct time
167221 if self .timestepper is not None :
168222 self .timestepper .t = fd .Constant (self .t - self .dt )
169223
170- self .dt = self .level .params .dt
224+ self .dt = fd . Constant ( self .level .params .dt * self . n_steps )
171225
172226 # update stats and output
173227 self .stats = {** self .stats , ** _stats }
0 commit comments