Commit 3b1bfdd

add mpi_distribute

1 parent: d2366ac
File tree: 1 file changed (+66, -31 lines)

meshmode/distributed.py

Lines changed: 66 additions & 31 deletions
@@ -3,6 +3,7 @@
 .. autoclass:: InterRankBoundaryInfo
 .. autoclass:: MPIBoundaryCommSetupHelper
 
+.. autofunction:: mpi_distribute
 .. autofunction:: get_partition_by_pymetis
 .. autofunction:: membership_list_to_map
 .. autofunction:: get_connected_partitions
@@ -37,6 +38,7 @@
 """
 
 from dataclasses import dataclass
+from contextlib import contextmanager
 import numpy as np
 from typing import List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING
 
@@ -66,12 +68,69 @@
 import logging
 logger = logging.getLogger(__name__)
 
-TAG_BASE = 83411
-TAG_DISTRIBUTE_MESHES = TAG_BASE + 1
-
 
 # {{{ mesh distributor
 
+@contextmanager
+def _duplicate_mpi_comm(mpi_comm):
+    dup_comm = mpi_comm.Dup()
+    try:
+        yield dup_comm
+    finally:
+        dup_comm.Free()
+
+
+def mpi_distribute(mpi_comm, source_data=None, source_rank=0):
+    """
+    Distribute data to a set of processes.
+
+    :arg mpi_comm: An ``MPI.Intracomm``
+    :arg source_data: A :class:`dict` mapping destination ranks to data to be sent.
+        Only present on the source rank.
+    :arg source_rank: The rank from which the data is being sent.
+    :returns: The data local to the current process if there is any, otherwise
+        *None*.
+    """
+    with _duplicate_mpi_comm(mpi_comm) as mpi_comm:
+        num_proc = mpi_comm.Get_size()
+        rank = mpi_comm.Get_rank()
+
+        local_data = None
+
+        if rank == source_rank:
+            if source_data is None:
+                raise TypeError("source rank has no data.")
+
+            sending_to = np.full(num_proc, False)
+            for dest_rank in source_data.keys():
+                sending_to[dest_rank] = True
+
+            mpi_comm.scatter(sending_to, root=source_rank)
+
+            reqs = []
+            for dest_rank, data in source_data.items():
+                if dest_rank == rank:
+                    local_data = data
+                    logger.info("rank %d: received data", rank)
+                else:
+                    reqs.append(mpi_comm.isend(data, dest=dest_rank))
+
+            logger.info("rank %d: sent all data", rank)
+
+            from mpi4py import MPI
+            MPI.Request.waitall(reqs)
+
+        else:
+            receiving = mpi_comm.scatter(None, root=source_rank)
+
+            if receiving:
+                local_data = mpi_comm.recv(source=source_rank)
+                logger.info("rank %d: received data", rank)
+
+        return local_data
+
+
+# TODO: Deprecate?
 class MPIMeshDistributor:
     """
     .. automethod:: is_mananger_rank
@@ -99,9 +158,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
         Sends each partition to a different rank.
         Returns one partition that was not sent to any other rank.
         """
-        mpi_comm = self.mpi_comm
-        rank = mpi_comm.Get_rank()
-        assert num_parts <= mpi_comm.Get_size()
+        assert num_parts <= self.mpi_comm.Get_size()
 
         assert self.is_mananger_rank()
 
@@ -110,38 +167,16 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
         from meshmode.mesh.processing import partition_mesh
         parts = partition_mesh(mesh, part_num_to_elements)
 
-        local_part = None
-
-        reqs = []
-        for r, part in parts.items():
-            if r == self.manager_rank:
-                local_part = part
-            else:
-                reqs.append(mpi_comm.isend(part, dest=r, tag=TAG_DISTRIBUTE_MESHES))
-
-        logger.info("rank %d: sent all mesh partitions", rank)
-        for req in reqs:
-            req.wait()
-
-        return local_part
+        return mpi_distribute(
+            self.mpi_comm, source_data=parts, source_rank=self.manager_rank)
 
     def receive_mesh_part(self):
         """
        Returns the mesh sent by the manager rank.
        """
-        mpi_comm = self.mpi_comm
-        rank = mpi_comm.Get_rank()
-
        assert not self.is_mananger_rank(), "Manager rank cannot receive mesh"
 
-        from mpi4py import MPI
-        status = MPI.Status()
-        result = self.mpi_comm.recv(
-            source=self.manager_rank, tag=TAG_DISTRIBUTE_MESHES,
-            status=status)
-        logger.info("rank %d: received local mesh (size = %d)", rank, status.count)
-
-        return result
+        return mpi_distribute(self.mpi_comm, source_rank=self.manager_rank)
 
 # }}}

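For context, a minimal usage sketch of the new helper (not part of this commit; the payloads and script name are illustrative). It assumes mpi4py is installed and the script is launched under an MPI runner, e.g. mpiexec -n 4 python demo.py. Note that mpi_distribute is collective: every rank in the communicator must call it, because the source rank scatters a per-rank flag before posting its sends.

    from mpi4py import MPI
    from meshmode.distributed import mpi_distribute

    comm = MPI.COMM_WORLD

    if comm.Get_rank() == 0:
        # On the source rank, map each destination rank to its payload.
        data = {rank: f"payload for rank {rank}"
                for rank in range(comm.Get_size())}
        local = mpi_distribute(comm, source_data=data, source_rank=0)
    else:
        # Non-source ranks pass no data; they get their entry back, or None
        # if the source did not address them.
        local = mpi_distribute(comm, source_rank=0)

    print(f"rank {comm.Get_rank()}: {local!r}")

Duplicating the communicator inside _duplicate_mpi_comm gives these messages their own communication context, isolated from unrelated traffic on the parent communicator, which is presumably what lets the commit drop the old TAG_DISTRIBUTE_MESHES tag.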