Commit 018eb5f

mpi-pytest is now on pip (#4048)
1 parent e038d72 · commit 018eb5f

File tree: 5 files changed (+9, -28 lines)


pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -67,11 +67,11 @@ spydump = "pyop2.scripts.spydump:main"
 
 [project.optional-dependencies]
 test = [
-  "pylit",
+  "mpi-pytest",
   "nbval",
+  "pylit",
   "pytest",
   "pytest-xdist",
-  "pytest-mpi @ git+https://github.com/firedrakeproject/pytest-mpi.git@main",
 ]
 dev = [
   "flake8",

requirements-ext.txt

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ Cython>=3.0
 decorator<=4.4.2
 flake8
 mpi4py
+mpi-pytest
 nbval
 numpy
 packaging

requirements-git.txt

Lines changed: 0 additions & 1 deletion
@@ -2,6 +2,5 @@ git+https://github.com/firedrakeproject/ufl.git#egg=fenics-ufl
 git+https://github.com/firedrakeproject/fiat.git#egg=fenics-fiat
 git+https://github.com/dolfin-adjoint/pyadjoint.git#egg=pyadjoint-ad
 git+https://github.com/firedrakeproject/loopy.git@main#egg=loopy
-git+https://github.com/firedrakeproject/pytest-mpi.git@main#egg=pytest-mpi
 git+https://github.com/firedrakeproject/petsc.git@firedrake#egg=petsc
 git+https://github.com/firedrakeproject/libsupermesh.git#egg=libsupermesh

scripts/firedrake-run-split-tests

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 #
 # * pytest
 # * pytest-split
-# * pytest-mpi (https://github.com/firedrakeproject/pytest-mpi)
+# * mpi-pytest
 # * GNU parallel
 
 num_procs=$1

tests/firedrake/regression/test_ensembleparallelism.py

Lines changed: 5 additions & 24 deletions
@@ -2,6 +2,7 @@
 from firedrake.petsc import DEFAULT_DIRECT_SOLVER_PARAMETERS
 from pyop2.mpi import MPI
 import pytest
+from pytest_mpi import parallel_assert
 
 from operator import mul
 from functools import reduce
@@ -25,24 +26,6 @@
     pytest.param(False, id="nonblocking")]
 
 
-def parallel_assert(assertion, subset=None, msg=""):
-    """ Move this functionality to pytest-mpi
-    """
-    if subset:
-        if MPI.COMM_WORLD.rank in subset:
-            evaluation = assertion()
-        else:
-            evaluation = True
-    else:
-        evaluation = assertion()
-    all_results = MPI.COMM_WORLD.allgather(evaluation)
-    if not all(all_results):
-        raise AssertionError(
-            "Parallel assertion failed on ranks: "
-            f"{[ii for ii, b in enumerate(all_results) if not b]}\n" + msg
-        )
-
-
 # unique profile on each mixed function component on each ensemble rank
 def function_profile(x, y, rank, cpt):
     return sin(cpt + (rank+1)*pi*x)*cos(cpt + (rank+1)*pi*y)
@@ -204,13 +187,13 @@ def test_ensemble_reduce(ensemble, mesh, W, urank, urank_sum, root, blocking):
     root_ranks = {ii + root*ensemble.comm.size for ii in range(ensemble.comm.size)}
     parallel_assert(
         lambda: error < 1e-12,
-        subset=root_ranks,
+        participating=COMM_WORLD.rank in root_ranks,
         msg=f"{error = :.5f}"  # noqa: E203, E251
     )
     error = errornorm(Function(W).assign(10), u_reduce)
     parallel_assert(
         lambda: error < 1e-12,
-        subset={range(COMM_WORLD.size)} - root_ranks,
+        participating=COMM_WORLD.rank not in root_ranks,
         msg=f"{error = :.5f}"  # noqa: E203, E251
     )
 
@@ -221,9 +204,7 @@ def test_ensemble_reduce(ensemble, mesh, W, urank, urank_sum, root, blocking):
     with u_reduce.dat.vec as v:
         states[spatial_rank] = v.stateGet()
     ensemble.comm.Allgather(MPI.IN_PLACE, states)
-    parallel_assert(
-        lambda: len(set(states)) == 1,
-    )
+    parallel_assert(lambda: len(set(states)) == 1)
 
 
 @pytest.mark.parallel(nprocs=2)
@@ -346,7 +327,7 @@ def test_send_and_recv(ensemble, mesh, W, blocking):
         root_ranks |= {ii + rank1*ensemble.comm.size for ii in range(ensemble.comm.size)}
     parallel_assert(
         lambda: error < 1e-12,
-        subset=root_ranks,
+        participating=COMM_WORLD.rank in root_ranks,
         msg=f"{error = :.5f}"  # noqa: E203, E251
     )
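
Every migrated call site follows the same pattern: the deleted helper gathered each rank's result over COMM_WORLD and raised on all ranks if any participating rank failed, and pytest_mpi.parallel_assert is invoked the same way, with a "participating" boolean replacing the "subset" set of ranks. The boolean form also fixes a latent bug: the old subset={range(COMM_WORLD.size)} - root_ranks built a set containing a single range object, so no integer rank was ever in the subset and that assertion was vacuous. A minimal sketch of the new usage (not from this commit; the keyword signature is inferred from the call sites above):

    # Sketch of the post-commit pattern (assumed signature, taken from the
    # call sites in this diff): parallel_assert(assertion, participating=..., msg=...).
    from mpi4py import MPI
    import pytest
    from pytest_mpi import parallel_assert


    @pytest.mark.parallel(nprocs=4)
    def test_even_ranks_see_zero():
        rank = MPI.COMM_WORLD.rank
        value = rank % 2  # 0 on even ranks, 1 on odd ranks
        # Every rank reaches this collective call; only even ranks actually
        # evaluate the assertion, and a failure on any of them fails the
        # test on all ranks.
        parallel_assert(
            lambda: value == 0,
            participating=rank % 2 == 0,
            msg=f"rank {rank} saw value {value}",
        )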
