from firedrake.petsc import DEFAULT_DIRECT_SOLVER_PARAMETERS
from pyop2.mpi import MPI
import pytest
+from pytest_mpi import parallel_assert

from operator import mul
from functools import reduce
            pytest.param(False, id="nonblocking")]


-def parallel_assert(assertion, subset=None, msg=""):
-    """ Move this functionality to pytest-mpi
-    """
-    if subset:
-        if MPI.COMM_WORLD.rank in subset:
-            evaluation = assertion()
-        else:
-            evaluation = True
-    else:
-        evaluation = assertion()
-    all_results = MPI.COMM_WORLD.allgather(evaluation)
-    if not all(all_results):
-        raise AssertionError(
-            "Parallel assertion failed on ranks: "
-            f"{[ii for ii, b in enumerate(all_results) if not b]}\n" + msg
-        )
-
-
# unique profile on each mixed function component on each ensemble rank
def function_profile(x, y, rank, cpt):
    return sin(cpt + (rank + 1)*pi*x)*cos(cpt + (rank + 1)*pi*y)
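
The helper removed above is superseded by the `parallel_assert` now imported from `pytest_mpi` at the top of the file; its old docstring already flagged that the functionality should move there. Judging from the removed code and the updated call sites in the hunks below, the new function takes a callable assertion, a per-rank boolean `participating` flag in place of the old `subset` of rank numbers, and an optional `msg`. A minimal sketch of the presumed semantics, assuming the same COMM_WORLD-wide allgather as the deleted helper (`parallel_assert_sketch` is a hypothetical name, not the actual pytest_mpi implementation):

from mpi4py import MPI


def parallel_assert_sketch(assertion, participating=True, msg=""):
    # Non-participating ranks report success, so only participating ranks
    # can fail the collective check.
    result = assertion() if participating else True
    # Every rank must reach this allgather, otherwise the collective call deadlocks.
    all_results = MPI.COMM_WORLD.allgather(result)
    if not all(all_results):
        failed = [rank for rank, ok in enumerate(all_results) if not ok]
        raise AssertionError(
            f"Parallel assertion failed on ranks: {failed}\n" + msg
        )
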
@@ -204,13 +187,13 @@ def test_ensemble_reduce(ensemble, mesh, W, urank, urank_sum, root, blocking):
    root_ranks = {ii + root*ensemble.comm.size for ii in range(ensemble.comm.size)}
    parallel_assert(
        lambda: error < 1e-12,
-        subset=root_ranks,
+        participating=COMM_WORLD.rank in root_ranks,
        msg=f"{error = :.5f}"  # noqa: E203, E251
    )
    error = errornorm(Function(W).assign(10), u_reduce)
    parallel_assert(
        lambda: error < 1e-12,
-        subset={range(COMM_WORLD.size)} - root_ranks,
+        participating=COMM_WORLD.rank not in root_ranks,
        msg=f"{error = :.5f}"  # noqa: E203, E251
    )

@@ -221,9 +204,7 @@ def test_ensemble_reduce(ensemble, mesh, W, urank, urank_sum, root, blocking):
    with u_reduce.dat.vec as v:
        states[spatial_rank] = v.stateGet()
    ensemble.comm.Allgather(MPI.IN_PLACE, states)
-    parallel_assert(
-        lambda: len(set(states)) == 1,
-    )
+    parallel_assert(lambda: len(set(states)) == 1)


@pytest.mark.parallel(nprocs=2)
@@ -346,7 +327,7 @@ def test_send_and_recv(ensemble, mesh, W, blocking):
    root_ranks |= {ii + rank1*ensemble.comm.size for ii in range(ensemble.comm.size)}
    parallel_assert(
        lambda: error < 1e-12,
-        subset=root_ranks,
+        participating=COMM_WORLD.rank in root_ranks,
        msg=f"{error = :.5f}"  # noqa: E203, E251
    )

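
Call sites migrate mechanically: where the old helper took `subset=<set of global rank numbers>`, the pytest_mpi version takes `participating=<bool>`, evaluated independently on each rank. A hedged usage sketch, with `root_ranks` and `error` standing in for the values computed in the tests above:

from pyop2.mpi import MPI
from pytest_mpi import parallel_assert

root_ranks = {0}   # placeholder for the ranks expected to hold the reduced result
error = 0.0        # placeholder for the errornorm computed in the test

# Old style (removed): parallel_assert(lambda: error < 1e-12, subset=root_ranks, msg=...)
# New style: each rank decides for itself whether it participates in the check.
parallel_assert(
    lambda: error < 1e-12,
    participating=MPI.COMM_WORLD.rank in root_ranks,
    msg=f"{error = :.5f}",  # noqa: E203, E251
)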