33.. autoclass:: InterRankBoundaryInfo
44.. autoclass:: MPIBoundaryCommSetupHelper
55
6+ .. autofunction:: mpi_distribute
67.. autofunction:: get_partition_by_pymetis
78.. autofunction:: membership_list_to_map
89.. autofunction:: get_connected_partitions
3738"""
3839
3940from dataclasses import dataclass
41+ from contextlib import contextmanager
4042import numpy as np
4143from typing import List , Set , Union , Mapping , cast , Sequence , TYPE_CHECKING
4244
6668import logging
6769logger = logging .getLogger (__name__ )
6870
69- TAG_BASE = 83411
70- TAG_DISTRIBUTE_MESHES = TAG_BASE + 1
71-
7271
7372# {{{ mesh distributor
7473
74+ @contextmanager
75+ def _duplicate_mpi_comm (mpi_comm ):
76+ dup_comm = mpi_comm .Dup ()
77+ try :
78+ yield dup_comm
79+ finally :
80+ dup_comm .Free ()
81+
82+
83+ def mpi_distribute (mpi_comm , source_data = None , source_rank = 0 ):
84+ """
85+ Distribute data to a set of processes.
86+
87+ :arg mpi_comm: An ``MPI.Intracomm``
88+ :arg source_data: A :class:`dict` mapping destination ranks to data to be sent.
89+ Only present on the source rank.
90+ :arg source_rank: The rank from which the data is being sent.
91+ :returns: The data local to the current process if there is any, otherwise
92+ *None*.
93+ """
94+ with _duplicate_mpi_comm (mpi_comm ) as mpi_comm :
95+ num_proc = mpi_comm .Get_size ()
96+ rank = mpi_comm .Get_rank ()
97+
98+ local_data = None
99+
100+ if rank == source_rank :
101+ if source_data is None :
102+ raise TypeError ("source rank has no data." )
103+
104+ sending_to = np .full (num_proc , False )
105+ for dest_rank in source_data .keys ():
106+ sending_to [dest_rank ] = True
107+
108+ mpi_comm .scatter (sending_to , root = source_rank )
109+
110+ reqs = []
111+ for dest_rank , data in source_data .items ():
112+ if dest_rank == rank :
113+ local_data = data
114+ logger .info ("rank %d: received data" , rank )
115+ else :
116+ reqs .append (mpi_comm .isend (data , dest = dest_rank ))
117+
118+ logger .info ("rank %d: sent all data" , rank )
119+
120+ from mpi4py import MPI
121+ MPI .Request .waitall (reqs )
122+
123+ else :
124+ receiving = mpi_comm .scatter (None , root = source_rank )
125+
126+ if receiving :
127+ local_data = mpi_comm .recv (source = source_rank )
128+ logger .info ("rank %d: received data" , rank )
129+
130+ return local_data
131+
132+
133+ # TODO: Deprecate?
75134class MPIMeshDistributor :
76135 """
77136 .. automethod:: is_mananger_rank
@@ -99,9 +158,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
99158 Sends each partition to a different rank.
100159 Returns one partition that was not sent to any other rank.
101160 """
102- mpi_comm = self .mpi_comm
103- rank = mpi_comm .Get_rank ()
104- assert num_parts <= mpi_comm .Get_size ()
161+ assert num_parts <= self .mpi_comm .Get_size ()
105162
106163 assert self .is_mananger_rank ()
107164
@@ -110,38 +167,16 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
110167 from meshmode .mesh .processing import partition_mesh
111168 parts = partition_mesh (mesh , part_num_to_elements )
112169
113- local_part = None
114-
115- reqs = []
116- for r , part in parts .items ():
117- if r == self .manager_rank :
118- local_part = part
119- else :
120- reqs .append (mpi_comm .isend (part , dest = r , tag = TAG_DISTRIBUTE_MESHES ))
121-
122- logger .info ("rank %d: sent all mesh partitions" , rank )
123- for req in reqs :
124- req .wait ()
125-
126- return local_part
170+ return mpi_distribute (
171+ self .mpi_comm , source_data = parts , source_rank = self .manager_rank )
127172
128173 def receive_mesh_part (self ):
129174 """
130175 Returns the mesh sent by the manager rank.
131176 """
132- mpi_comm = self .mpi_comm
133- rank = mpi_comm .Get_rank ()
134-
135177 assert not self .is_mananger_rank (), "Manager rank cannot receive mesh"
136178
137- from mpi4py import MPI
138- status = MPI .Status ()
139- result = self .mpi_comm .recv (
140- source = self .manager_rank , tag = TAG_DISTRIBUTE_MESHES ,
141- status = status )
142- logger .info ("rank %d: received local mesh (size = %d)" , rank , status .count )
143-
144- return result
179+ return mpi_distribute (self .mpi_comm , source_rank = self .manager_rank )
145180
146181# }}}
147182
0 commit comments