.. autoclass:: InterRankBoundaryInfo
.. autoclass:: MPIBoundaryCommSetupHelper

+ .. autofunction:: mpi_distribute
.. autofunction:: get_partition_by_pymetis
.. autofunction:: membership_list_to_sets
.. autofunction:: get_connected_partitions
3435"""
3536
3637from dataclasses import dataclass
38+ from contextlib import contextmanager
3739import numpy as np
3840from typing import (
3941 List , Set , Union , Mapping , Hashable , cast , Sequence , TYPE_CHECKING )
6365import logging
6466logger = logging .getLogger (__name__ )
6567
- TAG_BASE = 83411
- TAG_DISTRIBUTE_MESHES = TAG_BASE + 1
-

# {{{ mesh distributor

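+ # Communication below happens on a private duplicate of the caller's
+ # communicator: messages on the duplicate cannot be matched by unrelated
+ # traffic on the original communicator.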
+ @contextmanager
+ def _duplicate_mpi_comm(mpi_comm):
+     dup_comm = mpi_comm.Dup()
+     try:
+         yield dup_comm
+     finally:
+         dup_comm.Free()
+
+
+ def mpi_distribute(mpi_comm, source_data=None, source_rank=0):
+     """
+     Distribute data to a set of processes.
+
+     :arg mpi_comm: An ``MPI.Intracomm``.
+     :arg source_data: A :class:`dict` mapping destination ranks to the data
+         to be sent to them. Must be provided on the source rank and must be
+         *None* on all other ranks.
+     :arg source_rank: The rank from which the data is sent.
+     :returns: The data local to the current process if there is any, otherwise
+         *None*.
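+
+     A minimal usage sketch (illustrative payloads; assumes the script runs
+     under ``mpiexec`` with at least two ranks):
+
+     .. code-block:: python
+
+         from mpi4py import MPI
+
+         comm = MPI.COMM_WORLD
+         if comm.rank == 0:
+             # rank 0 sends one payload to every rank, including itself
+             data = {dest: f"payload for rank {dest}" for dest in range(comm.size)}
+         else:
+             data = None
+
+         local = mpi_distribute(comm, source_data=data, source_rank=0)
+         assert local == f"payload for rank {comm.rank}"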
+     """
+     with _duplicate_mpi_comm(mpi_comm) as mpi_comm:
+         num_proc = mpi_comm.Get_size()
+         rank = mpi_comm.Get_rank()
+
+         local_data = None
+
+         if rank == source_rank:
+             if source_data is None:
+                 raise TypeError("source rank has no data.")
+
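+             # One flag per destination rank; the scatter below tells every
+             # rank whether a payload is on the way, so ranks that receive
+             # nothing do not block in a recv.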
+             sending_to = np.full(num_proc, False)
+             for dest_rank in source_data.keys():
+                 sending_to[dest_rank] = True
+
+             mpi_comm.scatter(sending_to, root=source_rank)
+
+             reqs = []
+             for dest_rank, data in source_data.items():
+                 if dest_rank == rank:
+                     local_data = data
+                     logger.info("rank %d: received data", rank)
+                 else:
+                     reqs.append(mpi_comm.isend(data, dest=dest_rank))
+
+             logger.info("rank %d: sent all data", rank)
+
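+             # Wait for all nonblocking sends to complete before the duplicated
+             # communicator is freed on exit from the context manager.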
+             from mpi4py import MPI
+             MPI.Request.waitall(reqs)
+
+         else:
+             receiving = mpi_comm.scatter(None, root=source_rank)
+
+             if receiving:
+                 local_data = mpi_comm.recv(source=source_rank)
+                 logger.info("rank %d: received data", rank)
+
+         return local_data
+
+
+ # TODO: Deprecate?
class MPIMeshDistributor:
    """
    .. automethod:: is_mananger_rank
@@ -96,9 +155,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
        Sends each partition to a different rank.
        Returns the one partition that was not sent to any other rank.
        """
-         mpi_comm = self.mpi_comm
-         rank = mpi_comm.Get_rank()
-         assert num_parts <= mpi_comm.Get_size()
+         assert num_parts <= self.mpi_comm.Get_size()

        assert self.is_mananger_rank()

@@ -108,38 +165,16 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
        from meshmode.mesh.processing import partition_mesh
        parts = partition_mesh(mesh, part_num_to_elements)

-         local_part = None
-
-         reqs = []
-         for r, part in parts.items():
-             if r == self.manager_rank:
-                 local_part = part
-             else:
-                 reqs.append(mpi_comm.isend(part, dest=r, tag=TAG_DISTRIBUTE_MESHES))
-
-         logger.info("rank %d: sent all mesh partitions", rank)
-         for req in reqs:
-             req.wait()
-
-         return local_part
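+         # mpi_distribute keeps the manager's own part locally and sends the
+         # remaining parts to their destination ranks, replacing the
+         # hand-rolled isend/wait loop above.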
+         return mpi_distribute(
+             self.mpi_comm, source_data=parts, source_rank=self.manager_rank)

    def receive_mesh_part(self):
        """
        Returns the mesh sent by the manager rank.
        """
-         mpi_comm = self.mpi_comm
-         rank = mpi_comm.Get_rank()
-
        assert not self.is_mananger_rank(), "Manager rank cannot receive mesh"

-         from mpi4py import MPI
-         status = MPI.Status()
-         result = self.mpi_comm.recv(
-             source=self.manager_rank, tag=TAG_DISTRIBUTE_MESHES,
-             status=status)
-         logger.info("rank %d: received local mesh (size = %d)", rank, status.count)
-
-         return result
+         return mpi_distribute(self.mpi_comm, source_rank=self.manager_rank)

# }}}
