Skip to content

Commit f9eb1f5

Browse files
Added multi-step SDC to the Gusto coupling (#521)
* Implemented multi-step SDC in the Gusto coupling * Fixed test
1 parent 8b1ed0c commit f9eb1f5

File tree

5 files changed

+388
-20
lines changed

5 files changed

+388
-20
lines changed

pySDC/helpers/firedrake_ensemble_communicator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def __init__(self, comm, space_size):
2121
ensemble (firedrake.Ensemble): Ensemble communicator
2222
"""
2323
self.ensemble = fd.Ensemble(comm, space_size)
24+
self.comm_wold = comm
25+
26+
def Split(self, *args, **kwargs):
27+
return FiredrakeEnsembleCommunicator(self.comm_wold.Split(*args, **kwargs), space_size=self.space_comm.size)
2428

2529
@property
2630
def space_comm(self):
@@ -53,6 +57,16 @@ def Bcast(self, buf, root=0):
5357
else:
5458
self.ensemble.bcast(buf, root=root)
5559

60+
def Irecv(self, buf, source, tag=MPI.ANY_TAG):
61+
if type(buf) in [np.ndarray, list]:
62+
return self.ensemble.ensemble_comm.Irecv(buf=buf, source=source, tag=tag)
63+
return self.ensemble.irecv(buf, source, tag=tag)[0]
64+
65+
def Isend(self, buf, dest, tag=MPI.ANY_TAG):
66+
if type(buf) in [np.ndarray, list]:
67+
return self.ensemble.ensemble_comm.Isend(buf=buf, dest=dest, tag=tag)
68+
return self.ensemble.isend(buf, dest, tag=tag)[0]
69+
5670

5771
def get_ensemble(comm, space_size):
5872
return fd.Ensemble(comm, space_size)

pySDC/helpers/pySDC_as_gusto_time_discretization.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
from gusto.core.labels import explicit
55

66
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
7+
from pySDC.implementations.controller_classes.controller_MPI import controller_MPI
78
from pySDC.implementations.problem_classes.GenericGusto import GenericGusto, GenericGustoImex
89
from pySDC.core.hooks import Hooks
910
from pySDC.helpers.stats_helper import get_sorted
1011

12+
import logging
13+
import numpy as np
14+
1115

1216
class LogTime(Hooks):
1317
"""
@@ -34,6 +38,10 @@ class pySDC_integrator(TimeDiscretisation):
3438
It will construct a pySDC controller which can be used by itself and will be used within the time step when called
3539
from Gusto. Access the controller via `pySDC_integrator.controller`. This class also has `pySDC_integrator.stats`,
3640
which gathers all of the pySDC stats recorded in the hooks during every time step when used within Gusto.
41+
42+
This class supports subcycling with multi-step SDC. You can use pseudo-parallelism by simply giving `n_steps` > 1 or
43+
do proper parallelism by giving a `controller_communicator` of kind `pySDC.FiredrakeEnsembleCommunicator` with the
44+
appropriate size. You also have to toggle between pseudo and proper parallelism with `useMPIController`.
3745
"""
3846

3947
def __init__(
@@ -44,8 +52,10 @@ def __init__(
4452
field_name=None,
4553
solver_parameters=None,
4654
options=None,
47-
t0=0,
4855
imex=False,
56+
useMPIController=False,
57+
n_steps=1,
58+
controller_communicator=None,
4959
):
5060
"""
5161
Initialization
@@ -63,6 +73,10 @@ def __init__(
6373
options to either be passed to the spatial discretisation, or
6474
to control the "wrapper" methods, such as Embedded DG or a
6575
recovery method. Defaults to None.
76+
imex (bool): Whether to use IMEX splitting
77+
useMPIController (bool): Whether to use the pseudo-parallel or proper parallel pySDC controller
78+
n_steps (int): Number of steps done in parallel when using pseudo-parallel pySDC controller
79+
controller_communicator (pySDC.FiredrakeEnsembleCommunicator, optional): Communicator for the proper parallel controller
6680
"""
6781

6882
self._residual = None
@@ -79,6 +93,23 @@ def __init__(
7993
self.timestepper = None
8094
self.dt_next = None
8195
self.imex = imex
96+
self.useMPIController = useMPIController
97+
self.controller_communicator = controller_communicator
98+
99+
if useMPIController:
100+
assert (
101+
type(self.controller_communicator).__name__ == 'FiredrakeEnsembleCommunicator'
102+
), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(self.controller_communicator)}'
103+
if n_steps > 1:
104+
logging.getLogger(type(self).__name__).warning(
105+
f'Warning: You selected {n_steps=}, which will be ignored when using the MPI controller!'
106+
)
107+
assert (
108+
controller_communicator is not None
109+
), 'You need to supply a communicator when using the MPI controller!'
110+
self.n_steps = controller_communicator.size
111+
else:
112+
self.n_steps = n_steps
82113

83114
def setup(self, equation, apply_bcs=True, *active_labels):
84115
super().setup(equation, apply_bcs, *active_labels)
@@ -96,7 +127,7 @@ def setup(self, equation, apply_bcs=True, *active_labels):
96127
'residual': self._residual,
97128
**self.description['problem_params'],
98129
}
99-
self.description['level_params']['dt'] = float(self.domain.dt)
130+
self.description['level_params']['dt'] = float(self.domain.dt) / self.n_steps
100131

101132
# add utility hook required for step size adaptivity
102133
hook_class = self.controller_params.get('hook_class', [])
@@ -106,7 +137,17 @@ def setup(self, equation, apply_bcs=True, *active_labels):
106137
self.controller_params['hook_class'] = hook_class
107138

108139
# prepare controller and variables
109-
self.controller = controller_nonMPI(1, description=self.description, controller_params=self.controller_params)
140+
if self.useMPIController:
141+
self.controller = controller_MPI(
142+
comm=self.controller_communicator,
143+
description=self.description,
144+
controller_params=self.controller_params,
145+
)
146+
else:
147+
self.controller = controller_nonMPI(
148+
self.n_steps, description=self.description, controller_params=self.controller_params
149+
)
150+
110151
self.prob = self.level.prob
111152
self.sweeper = self.level.sweep
112153
self.x0_pySDC = self.prob.dtype_u(self.prob.init)
@@ -125,14 +166,26 @@ def residual(self):
125166
def residual(self, value):
126167
"""Make sure the pySDC problem residual and this residual are the same"""
127168
if hasattr(self, 'prob'):
128-
self.prob.residual = value
169+
if self.useMPIController:
170+
self.controller.S.levels[0].prob.residual = value
171+
else:
172+
for S in self.controller.MS:
173+
S.levels[0].prob.residual = value
129174
else:
130175
self._residual = value
131176

177+
@property
178+
def step(self):
179+
"""Get the first step on the controller"""
180+
if self.useMPIController:
181+
return self.controller.S
182+
else:
183+
return self.controller.MS[0]
184+
132185
@property
133186
def level(self):
134187
"""Get the finest pySDC level"""
135-
return self.controller.MS[0].levels[0]
188+
return self.step.levels[0]
136189

137190
@wrapper_apply
138191
def apply(self, x_out, x_in):
@@ -144,29 +197,31 @@ def apply(self, x_out, x_in):
144197
x_in (:class:`Function`): the input field.
145198
"""
146199
self.x0_pySDC.functionspace.assign(x_in)
147-
assert self.level.params.dt == float(self.dt), 'Step sizes have diverged between pySDC and Gusto'
200+
assert np.isclose(
201+
self.level.params.dt * self.n_steps, float(self.dt)
202+
), 'Step sizes have diverged between pySDC and Gusto'
148203

149204
if self.dt_next is not None:
150205
assert (
151206
self.timestepper is not None
152207
), 'You need to set self.timestepper to the timestepper in order to facilitate adaptive step size selection here!'
153-
self.timestepper.dt = fd.Constant(self.dt_next)
208+
self.timestepper.dt = fd.Constant(self.dt_next * self.n_steps)
154209
self.t = self.timestepper.t
155210

156211
uend, _stats = self.controller.run(u0=self.x0_pySDC, t0=float(self.t), Tend=float(self.t + self.dt))
157212

158213
# update time variables
159-
if self.level.params.dt != float(self.dt):
214+
if not np.isclose(self.level.params.dt * self.n_steps, float(self.dt)):
160215
self.dt_next = self.level.params.dt
161216

162-
self.t = get_sorted(_stats, type='_time', recomputed=False)[-1][1]
217+
self.t = get_sorted(_stats, type='_time', recomputed=False, comm=self.controller_communicator)[-1][1]
163218

164219
# update time of the Gusto stepper.
165220
# After this step, the Gusto stepper updates its time again to arrive at the correct time
166221
if self.timestepper is not None:
167222
self.timestepper.t = fd.Constant(self.t - self.dt)
168223

169-
self.dt = self.level.params.dt
224+
self.dt = fd.Constant(self.level.params.dt * self.n_steps)
170225

171226
# update stats and output
172227
self.stats = {**self.stats, **_stats}

pySDC/implementations/datatype_classes/firedrake_mesh.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import firedrake as fd
22

33
from pySDC.core.errors import DataError
4+
from pySDC.helpers.firedrake_ensemble_communicator import FiredrakeEnsembleCommunicator
45

56

67
class firedrake_mesh(object):
@@ -77,6 +78,57 @@ def __abs__(self):
7778

7879
return fd.norm(self.functionspace, 'L2')
7980

81+
def isend(self, dest=None, tag=None, comm=None):
82+
"""
83+
Routine for sending data forward in time (non-blocking)
84+
85+
Args:
86+
dest (int): target rank
87+
tag (int): communication tag
88+
comm: communicator
89+
90+
Returns:
91+
request handle
92+
"""
93+
assert (
94+
type(comm) == FiredrakeEnsembleCommunicator
95+
), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(comm)}'
96+
return comm.Isend(self.functionspace, dest=dest, tag=tag)
97+
98+
def irecv(self, source=None, tag=None, comm=None):
99+
"""
100+
Routine for receiving in time
101+
102+
Args:
103+
source (int): source rank
104+
tag (int): communication tag
105+
comm: communicator
106+
107+
Returns:
108+
None
109+
"""
110+
assert (
111+
type(comm) == FiredrakeEnsembleCommunicator
112+
), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(comm)}'
113+
return comm.Irecv(self.functionspace, source=source, tag=tag)
114+
115+
def bcast(self, root=None, comm=None):
116+
"""
117+
Routine for broadcasting values
118+
119+
Args:
120+
root (int): process with value to broadcast
121+
comm: communicator
122+
123+
Returns:
124+
broadcasted values
125+
"""
126+
assert (
127+
type(comm) == FiredrakeEnsembleCommunicator
128+
), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(comm)}'
129+
comm.Bcast(self.functionspace, root=root)
130+
return self
131+
80132

81133
class IMEX_firedrake_mesh(object):
82134
"""

pySDC/tests/test_datatypes/test_firedrake_mesh.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,5 +151,86 @@ def test_rmul_rhs(n=3, v1=1, v2=2):
151151
assert np.allclose(b.expl.dat._numpy_data, v2 * v1)
152152

153153

154+
def _test_p2p_communication(comm, u):
155+
import numpy as np
156+
157+
assert comm.size == 2
158+
if comm.rank == 0:
159+
u.assign(3.14)
160+
req = u.isend(dest=1, comm=comm, tag=0)
161+
elif comm.rank == 1:
162+
assert not np.allclose(u.dat._numpy_data, 3.14)
163+
req = u.irecv(source=0, comm=comm, tag=0)
164+
req.wait()
165+
assert np.allclose(u.dat._numpy_data, 3.14)
166+
167+
168+
def _test_bcast(comm, u):
169+
import numpy as np
170+
171+
if comm.rank == 0:
172+
u.assign(3.14)
173+
else:
174+
assert not np.allclose(u.dat._numpy_data, 3.14)
175+
u.bcast(root=0, comm=comm)
176+
assert np.allclose(u.dat._numpy_data, 3.14)
177+
178+
179+
@pytest.mark.firedrake
180+
@pytest.mark.parametrize('pattern', ['p2p', 'bcast'])
181+
def test_communication(pattern, n=2, submit=True):
182+
if submit:
183+
import os
184+
import subprocess
185+
186+
my_env = os.environ.copy()
187+
my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml'
188+
cwd = '.'
189+
num_procs = 2
190+
cmd = f'mpiexec -np {num_procs} python {__file__} --pattern {pattern}'.split()
191+
192+
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=my_env, cwd=cwd)
193+
p.wait()
194+
for line in p.stdout:
195+
print(line)
196+
for line in p.stderr:
197+
print(line)
198+
assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (
199+
p.returncode,
200+
num_procs,
201+
)
202+
203+
else:
204+
import firedrake as fd
205+
from pySDC.helpers.firedrake_ensemble_communicator import FiredrakeEnsembleCommunicator
206+
from pySDC.implementations.datatype_classes.firedrake_mesh import firedrake_mesh
207+
208+
ensemble_comm = FiredrakeEnsembleCommunicator(fd.COMM_WORLD, 1)
209+
210+
mesh = fd.UnitSquareMesh(n, n, comm=ensemble_comm.space_comm)
211+
V = fd.VectorFunctionSpace(mesh, "CG", 2)
212+
213+
u = firedrake_mesh(V)
214+
215+
if pattern == 'p2p':
216+
_test_p2p_communication(ensemble_comm, u)
217+
elif pattern == 'bcast':
218+
_test_bcast(ensemble_comm, u)
219+
else:
220+
raise NotImplementedError
221+
222+
154223
if __name__ == '__main__':
155-
test_addition()
224+
from argparse import ArgumentParser
225+
226+
parser = ArgumentParser()
227+
parser.add_argument(
228+
'--pattern',
229+
help="pattern for parallel tests",
230+
type=str,
231+
default=None,
232+
)
233+
args = parser.parse_args()
234+
235+
if args.pattern:
236+
test_communication(pattern=args.pattern, submit=False)

0 commit comments

Comments
 (0)