
Commit c7e6638

add mpi_distribute
1 parent 103f3ae commit c7e6638

1 file changed: +66 -31 lines changed

meshmode/distributed.py

@@ -3,6 +3,7 @@
 .. autoclass:: InterRankBoundaryInfo
 .. autoclass:: MPIBoundaryCommSetupHelper
 
+.. autofunction:: mpi_distribute
 .. autofunction:: get_partition_by_pymetis
 .. autofunction:: membership_list_to_sets
 .. autofunction:: get_connected_partitions
@@ -34,6 +35,7 @@
 """
 
 from dataclasses import dataclass
+from contextlib import contextmanager
 import numpy as np
 from typing import (
         List, Set, Union, Mapping, Hashable, cast, Sequence, TYPE_CHECKING)
@@ -63,12 +65,69 @@
 import logging
 logger = logging.getLogger(__name__)
 
-TAG_BASE = 83411
-TAG_DISTRIBUTE_MESHES = TAG_BASE + 1
-
 
 # {{{ mesh distributor
 
+@contextmanager
+def _duplicate_mpi_comm(mpi_comm):
+    dup_comm = mpi_comm.Dup()
+    try:
+        yield dup_comm
+    finally:
+        dup_comm.Free()
+
+
+def mpi_distribute(mpi_comm, source_data=None, source_rank=0):
+    """
+    Distribute data to a set of processes.
+
+    :arg mpi_comm: An ``MPI.Intracomm``
+    :arg source_data: A :class:`dict` mapping destination ranks to data to be sent.
+        Only present on the source rank.
+    :arg source_rank: The rank from which the data is being sent.
+    :returns: The data local to the current process if there is any, otherwise
+        *None*.
+    """
+    with _duplicate_mpi_comm(mpi_comm) as mpi_comm:
+        num_proc = mpi_comm.Get_size()
+        rank = mpi_comm.Get_rank()
+
+        local_data = None
+
+        if rank == source_rank:
+            if source_data is None:
+                raise TypeError("source rank has no data.")
+
+            sending_to = np.full(num_proc, False)
+            for dest_rank in source_data.keys():
+                sending_to[dest_rank] = True
+
+            mpi_comm.scatter(sending_to, root=source_rank)
+
+            reqs = []
+            for dest_rank, data in source_data.items():
+                if dest_rank == rank:
+                    local_data = data
+                    logger.info("rank %d: received data", rank)
+                else:
+                    reqs.append(mpi_comm.isend(data, dest=dest_rank))
+
+            logger.info("rank %d: sent all data", rank)
+
+            from mpi4py import MPI
+            MPI.Request.waitall(reqs)
+        else:
+            receiving = mpi_comm.scatter(None, root=source_rank)
+
+            if receiving:
+                local_data = mpi_comm.recv(source=source_rank)
+                logger.info("rank %d: received data", rank)
+
+        return local_data
+
+
+# TODO: Deprecate?
 class MPIMeshDistributor:
     """
     .. automethod:: is_mananger_rank
@@ -96,9 +155,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
         Sends each partition to a different rank.
         Returns one partition that was not sent to any other rank.
         """
-        mpi_comm = self.mpi_comm
-        rank = mpi_comm.Get_rank()
-        assert num_parts <= mpi_comm.Get_size()
+        assert num_parts <= self.mpi_comm.Get_size()
 
         assert self.is_mananger_rank()
 
@@ -108,38 +165,16 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
         from meshmode.mesh.processing import partition_mesh
         parts = partition_mesh(mesh, part_num_to_elements)
 
-        local_part = None
-
-        reqs = []
-        for r, part in parts.items():
-            if r == self.manager_rank:
-                local_part = part
-            else:
-                reqs.append(mpi_comm.isend(part, dest=r, tag=TAG_DISTRIBUTE_MESHES))
-
-        logger.info("rank %d: sent all mesh partitions", rank)
-        for req in reqs:
-            req.wait()
-
-        return local_part
+        return mpi_distribute(
+                self.mpi_comm, source_data=parts, source_rank=self.manager_rank)
 
     def receive_mesh_part(self):
         """
         Returns the mesh sent by the manager rank.
         """
-        mpi_comm = self.mpi_comm
-        rank = mpi_comm.Get_rank()
-
         assert not self.is_mananger_rank(), "Manager rank cannot receive mesh"
 
-        from mpi4py import MPI
-        status = MPI.Status()
-        result = self.mpi_comm.recv(
-                source=self.manager_rank, tag=TAG_DISTRIBUTE_MESHES,
-                status=status)
-        logger.info("rank %d: received local mesh (size = %d)", rank, status.count)
-
-        return result
+        return mpi_distribute(self.mpi_comm, source_rank=self.manager_rank)
 
 # }}}
 
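Usage sketch (not part of this commit): the snippet below shows how the new mpi_distribute helper might be driven from a hypothetical standalone script run under mpiexec. The script name, payload contents, and chunking scheme are illustrative assumptions; only the import path and the mpi_distribute signature come from the diff above.

# distribute_demo.py (hypothetical); run with e.g.
#     mpiexec -n 4 python distribute_demo.py
import numpy as np
from mpi4py import MPI

from meshmode.distributed import mpi_distribute

comm = MPI.COMM_WORLD

if comm.Get_rank() == 0:
    # Source rank: map each destination rank to its payload. Ranks left out
    # of this dict simply get None back from mpi_distribute.
    payloads = {
        dest: np.arange(10) + 10*dest
        for dest in range(comm.Get_size())}
    local_chunk = mpi_distribute(comm, source_data=payloads, source_rank=0)
else:
    # Non-source ranks pass no data; the helper's internal scatter tells them
    # whether a message is coming, and they receive it if so.
    local_chunk = mpi_distribute(comm, source_rank=0)

print(f"rank {comm.Get_rank()}: got {local_chunk}")

Two design points worth noting: the boolean scatter means non-source ranks never block on a recv that will not arrive, and running the exchange on a Dup()'d communicator gives the messages their own communication context, which appears to be why the fixed TAG_DISTRIBUTE_MESHES tag could be dropped.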