14 changes: 14 additions & 0 deletions pySDC/helpers/firedrake_ensemble_communicator.py
@@ -21,6 +21,10 @@ def __init__(self, comm, space_size):
ensemble (firedrake.Ensemble): Ensemble communicator
"""
self.ensemble = fd.Ensemble(comm, space_size)
self.comm_world = comm

def Split(self, *args, **kwargs):
return FiredrakeEnsembleCommunicator(self.comm_world.Split(*args, **kwargs), space_size=self.space_comm.size)

@property
def space_comm(self):
@@ -53,6 +57,16 @@ def Bcast(self, buf, root=0):
else:
self.ensemble.bcast(buf, root=root)

def Irecv(self, buf, source, tag=MPI.ANY_TAG):
if type(buf) in [np.ndarray, list]:
return self.ensemble.ensemble_comm.Irecv(buf=buf, source=source, tag=tag)
return self.ensemble.irecv(buf, source, tag=tag)[0]

def Isend(self, buf, dest, tag=MPI.ANY_TAG):
if type(buf) in [np.ndarray, list]:
return self.ensemble.ensemble_comm.Isend(buf=buf, dest=dest, tag=tag)
return self.ensemble.isend(buf, dest, tag=tag)[0]


def get_ensemble(comm, space_size):
return fd.Ensemble(comm, space_size)
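These wrappers dispatch on the buffer type: raw numpy arrays and lists go straight through the underlying `ensemble_comm`, while Firedrake functions use `fd.Ensemble`'s `isend`/`irecv`, which return a list of requests (hence the trailing `[0]`). A minimal usage sketch, assuming an MPI-enabled Firedrake build; the `space_size=2` layout, tag, and buffer size are illustrative:

import firedrake as fd
import numpy as np
from pySDC.helpers.firedrake_ensemble_communicator import FiredrakeEnsembleCommunicator

# With 4 MPI ranks and space_size=2, two ranks remain in the time direction.
ensemble = FiredrakeEnsembleCommunicator(fd.COMM_WORLD, space_size=2)

buf = np.zeros(4)
if ensemble.rank == 0:
    buf[:] = 1.0
    req = ensemble.Isend(buf, dest=1, tag=0)  # numpy buffer -> plain ensemble_comm.Isend
elif ensemble.rank == 1:
    req = ensemble.Irecv(buf, source=0, tag=0)
req.wait()  # either way, a single request handle comes back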
78 changes: 66 additions & 12 deletions pySDC/helpers/pySDC_as_gusto_time_discretization.py
@@ -4,10 +4,14 @@
from gusto.core.labels import explicit

from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.implementations.controller_classes.controller_MPI import controller_MPI
from pySDC.implementations.problem_classes.GenericGusto import GenericGusto, GenericGustoImex
from pySDC.core.hooks import Hooks
from pySDC.helpers.stats_helper import get_sorted

import logging
import numpy as np


class LogTime(Hooks):
"""
@@ -34,25 +38,29 @@ class pySDC_integrator(TimeDiscretisation):
It will construct a pySDC controller that can be used on its own and that is run within each time step when called
from Gusto. Access the controller via `pySDC_integrator.controller`. This class also has `pySDC_integrator.stats`,
which gathers all pySDC stats recorded in the hooks during every time step when used within Gusto.

This class supports subcycling with multi-step SDC. Use pseudo-parallelism by setting `n_steps` > 1, or true MPI
parallelism by passing a `controller_communicator` of type `pySDC.FiredrakeEnsembleCommunicator` with the
appropriate size. Toggle between pseudo and true parallelism with `useMPIController`.
"""

def __init__(
self,
equation,
description,
controller_params,
domain,
field_name=None,
solver_parameters=None,
options=None,
t0=0,
imex=False,
useMPIController=False,
n_steps=1,
controller_communicator=None,
):
"""
Initialization

Args:
equation (:class:`PrognosticEquation`): the prognostic equation.
description (dict): pySDC description
controller_params (dict): pySDC controller params
domain (:class:`Domain`): the model's domain object, containing the
@@ -65,6 +73,10 @@ def __init__(
options to either be passed to the spatial discretisation, or
to control the "wrapper" methods, such as Embedded DG or a
recovery method. Defaults to None.
imex (bool): Whether to use IMEX splitting
useMPIController (bool): Whether to use the MPI-parallel instead of the pseudo-parallel pySDC controller
n_steps (int): Number of steps done in parallel when using the pseudo-parallel pySDC controller
controller_communicator (pySDC.FiredrakeEnsembleCommunicator, optional): Communicator for the MPI-parallel controller
"""

self._residual = None
@@ -81,6 +93,23 @@ def __init__(
self.timestepper = None
self.dt_next = None
self.imex = imex
self.useMPIController = useMPIController
self.controller_communicator = controller_communicator

if useMPIController:
assert (
type(self.controller_communicator).__name__ == 'FiredrakeEnsembleCommunicator'
), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(self.controller_communicator)}'
if n_steps > 1:
logging.getLogger(type(self).__name__).warning(
f'Warning: You selected {n_steps=}, which will be ignored when using the MPI controller!'
)
assert (
controller_communicator is not None
), 'You need to supply a communicator when using the MPI controller!'
self.n_steps = controller_communicator.size
else:
self.n_steps = n_steps

def setup(self, equation, apply_bcs=True, *active_labels):
super().setup(equation, apply_bcs, *active_labels)
@@ -96,8 +125,9 @@ def setup(self, equation, apply_bcs=True, *active_labels):
'equation': equation,
'solver_parameters': self.solver_parameters,
'residual': self._residual,
**self.description['problem_params'],
}
self.description['level_params']['dt'] = float(self.domain.dt)
self.description['level_params']['dt'] = float(self.domain.dt) / self.n_steps

# add utility hook required for step size adaptivity
hook_class = self.controller_params.get('hook_class', [])
@@ -107,7 +137,17 @@ def setup(self, equation, apply_bcs=True, *active_labels):
self.controller_params['hook_class'] = hook_class

# prepare controller and variables
self.controller = controller_nonMPI(1, description=self.description, controller_params=self.controller_params)
if self.useMPIController:
self.controller = controller_MPI(
comm=self.controller_communicator,
description=self.description,
controller_params=self.controller_params,
)
else:
self.controller = controller_nonMPI(
self.n_steps, description=self.description, controller_params=self.controller_params
)

self.prob = self.level.prob
self.sweeper = self.level.sweep
self.x0_pySDC = self.prob.dtype_u(self.prob.init)
@@ -126,14 +166,26 @@ def residual(self):
def residual(self, value):
"""Make sure the pySDC problem residual and this residual are the same"""
if hasattr(self, 'prob'):
self.prob.residual = value
if self.useMPIController:
self.controller.S.levels[0].prob.residual = value
else:
for S in self.controller.MS:
S.levels[0].prob.residual = value
else:
self._residual = value

@property
def step(self):
"""Get the first step on the controller"""
if self.useMPIController:
return self.controller.S
else:
return self.controller.MS[0]

@property
def level(self):
"""Get the finest pySDC level"""
return self.controller.MS[0].levels[0]
return self.step.levels[0]

@wrapper_apply
def apply(self, x_out, x_in):
@@ -145,29 +197,31 @@ def apply(self, x_out, x_in):
x_in (:class:`Function`): the input field.
"""
self.x0_pySDC.functionspace.assign(x_in)
assert self.level.params.dt == float(self.dt), 'Step sizes have diverged between pySDC and Gusto'
assert np.isclose(
self.level.params.dt * self.n_steps, float(self.dt)
), 'Step sizes have diverged between pySDC and Gusto'

if self.dt_next is not None:
assert (
self.timestepper is not None
), 'You need to set self.timestepper to the timestepper in order to facilitate adaptive step size selection here!'
self.timestepper.dt = fd.Constant(self.dt_next)
self.timestepper.dt = fd.Constant(self.dt_next * self.n_steps)
self.t = self.timestepper.t

uend, _stats = self.controller.run(u0=self.x0_pySDC, t0=float(self.t), Tend=float(self.t + self.dt))

# update time variables
if self.level.params.dt != float(self.dt):
if not np.isclose(self.level.params.dt * self.n_steps, float(self.dt)):
self.dt_next = self.level.params.dt

self.t = get_sorted(_stats, type='_time', recomputed=False)[-1][1]
self.t = get_sorted(_stats, type='_time', recomputed=False, comm=self.controller_communicator)[-1][1]

# update time of the Gusto stepper.
# After this step, the Gusto stepper updates its time again to arrive at the correct time
if self.timestepper is not None:
self.timestepper.t = fd.Constant(self.t - self.dt)

self.dt = self.level.params.dt
self.dt = fd.Constant(self.level.params.dt * self.n_steps)

# update stats and output
self.stats = {**self.stats, **_stats}
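Putting the pieces together, a sketch of wiring the MPI controller into a Gusto run; the `equation`, `description`, `controller_params`, and `domain` objects are placeholders for whatever the surrounding Gusto/pySDC setup provides:

import firedrake as fd
from pySDC.helpers.firedrake_ensemble_communicator import FiredrakeEnsembleCommunicator
from pySDC.helpers.pySDC_as_gusto_time_discretization import pySDC_integrator

# With 4 MPI ranks and space_size=2, two ranks step forward in time in parallel.
comm = FiredrakeEnsembleCommunicator(fd.COMM_WORLD, space_size=2)

scheme = pySDC_integrator(
    equation,            # placeholder: a Gusto PrognosticEquation
    description,         # placeholder: pySDC description dict
    controller_params,   # placeholder: pySDC controller parameters
    domain,              # placeholder: the Gusto Domain
    useMPIController=True,
    controller_communicator=comm,  # n_steps is taken from comm.size
)

Each substep then advances `dt / n_steps`, so a Gusto step of 600 s split over two time ranks runs two pySDC steps of 300 s each.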
52 changes: 52 additions & 0 deletions pySDC/implementations/datatype_classes/firedrake_mesh.py
@@ -1,6 +1,7 @@
import firedrake as fd

from pySDC.core.errors import DataError
from pySDC.helpers.firedrake_ensemble_communicator import FiredrakeEnsembleCommunicator


class firedrake_mesh(object):
@@ -77,6 +78,57 @@ def __abs__(self):

return fd.norm(self.functionspace, 'L2')

def isend(self, dest=None, tag=None, comm=None):
"""
Routine for sending data forward in time (non-blocking)

Args:
dest (int): target rank
tag (int): communication tag
comm: communicator

Returns:
request handle
"""
assert (
type(comm) == FiredrakeEnsembleCommunicator
), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(comm)}'
return comm.Isend(self.functionspace, dest=dest, tag=tag)

def irecv(self, source=None, tag=None, comm=None):
"""
Routine for receiving data in time (non-blocking)

Args:
source (int): source rank
tag (int): communication tag
comm: communicator

Returns:
request handle
"""
assert (
type(comm) == FiredrakeEnsembleCommunicator
), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(comm)}'
return comm.Irecv(self.functionspace, source=source, tag=tag)

def bcast(self, root=None, comm=None):
"""
Routine for broadcasting values

Args:
root (int): process with value to broadcast
comm: communicator

Returns:
broadcasted values
"""
assert (
type(comm) == FiredrakeEnsembleCommunicator
), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(comm)}'
comm.Bcast(self.functionspace, root=root)
return self


class IMEX_firedrake_mesh(object):
"""
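The new hooks mirror the communication interface of the other pySDC datatypes. A minimal round trip, reusing the `ensemble` communicator from the sketch above and assuming two time ranks (the test added below drives the same pattern under `mpiexec`):

import firedrake as fd
from pySDC.implementations.datatype_classes.firedrake_mesh import firedrake_mesh

mesh = fd.UnitSquareMesh(2, 2, comm=ensemble.space_comm)
V = fd.FunctionSpace(mesh, 'CG', 1)
u = firedrake_mesh(V)

if ensemble.rank == 0:
    u.assign(1.0)
    u.isend(dest=1, tag=0, comm=ensemble).wait()
elif ensemble.rank == 1:
    u.irecv(source=0, tag=0, comm=ensemble).wait()
u.bcast(root=1, comm=ensemble)  # every time rank now holds the value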
83 changes: 82 additions & 1 deletion pySDC/tests/test_datatypes/test_firedrake_mesh.py
@@ -151,5 +151,86 @@ def test_rmul_rhs(n=3, v1=1, v2=2):
assert np.allclose(b.expl.dat._numpy_data, v2 * v1)


def _test_p2p_communication(comm, u):
import numpy as np

assert comm.size == 2
if comm.rank == 0:
u.assign(3.14)
req = u.isend(dest=1, comm=comm, tag=0)
elif comm.rank == 1:
assert not np.allclose(u.dat._numpy_data, 3.14)
req = u.irecv(source=0, comm=comm, tag=0)
req.wait()
assert np.allclose(u.dat._numpy_data, 3.14)


def _test_bcast(comm, u):
import numpy as np

if comm.rank == 0:
u.assign(3.14)
else:
assert not np.allclose(u.dat._numpy_data, 3.14)
u.bcast(root=0, comm=comm)
assert np.allclose(u.dat._numpy_data, 3.14)


@pytest.mark.firedrake
@pytest.mark.parametrize('pattern', ['p2p', 'bcast'])
def test_communication(pattern, n=2, submit=True):
if submit:
import os
import subprocess

my_env = os.environ.copy()
my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml'
cwd = '.'
num_procs = 2
cmd = f'mpiexec -np {num_procs} python {__file__} --pattern {pattern}'.split()

p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=my_env, cwd=cwd)
p.wait()
for line in p.stdout:
print(line)
for line in p.stderr:
print(line)
assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (
p.returncode,
num_procs,
)

else:
import firedrake as fd
from pySDC.helpers.firedrake_ensemble_communicator import FiredrakeEnsembleCommunicator
from pySDC.implementations.datatype_classes.firedrake_mesh import firedrake_mesh

ensemble_comm = FiredrakeEnsembleCommunicator(fd.COMM_WORLD, 1)

mesh = fd.UnitSquareMesh(n, n, comm=ensemble_comm.space_comm)
V = fd.VectorFunctionSpace(mesh, "CG", 2)

u = firedrake_mesh(V)

if pattern == 'p2p':
_test_p2p_communication(ensemble_comm, u)
elif pattern == 'bcast':
_test_bcast(ensemble_comm, u)
else:
raise NotImplementedError


if __name__ == '__main__':
test_addition()
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
'--pattern',
help="pattern for parallel tests",
type=str,
default=None,
)
args = parser.parse_args()

if args.pattern:
test_communication(pattern=args.pattern, submit=False)
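The parallel patterns can also be launched by hand in an MPI-enabled Firedrake environment, e.g. `mpiexec -np 2 python pySDC/tests/test_datatypes/test_firedrake_mesh.py --pattern p2p`.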