@@ -190,30 +190,41 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
 
     @staticmethod
     def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
-        """
-        Configure a square process grid from a parent MPI communicator and select the subset of "active" processes.
+        r"""Configure active grid
 
-        Each process in base_comm is assigned to a logical 2D grid of size p_prime x p_prime,
-        where p_prime = floor(sqrt(total_ranks)). Only the first `active_dim x active_dim` processes
-        (by row-major order) are considered "active". Inactive ranks return immediately with no new communicator.
+        Configure a square process grid from a parent MPI communicator and
+        select the subset of "active" processes. Each process in ``base_comm``
+        is assigned to a logical 2D grid of size :math:`P' \times P'`,
+        where :math:`P' = \bigl\lceil \sqrt{P} \bigr\rceil`. Only the first
+        ``active_dim x active_dim`` processes (in row-major order) are
+        considered "active". Inactive ranks return immediately with no new
+        communicator.
 
         Parameters:
         -----------
-        base_comm : MPI.Comm
-            The parent communicator (e.g., MPI.COMM_WORLD).
-        N : int
-            Number of rows of your global data domain.
-        M : int
-            Number of columns of your global data domain.
+        base_comm : :obj:`mpi4py.MPI.Comm`
+            MPI parent communicator (e.g., ``mpi4py.MPI.COMM_WORLD``).
+        N : :obj:`int`
+            Number of rows of the global data domain.
+        M : :obj:`int`
+            Number of columns of the global data domain.
 
         Returns:
         --------
-        tuple:
-            comm (MPI.Comm or None) : Sub-communicator including only active ranks.
-            rank (int) : Rank within the new sub-communicator (or original rank if inactive).
-            row (int) : Grid row index of this process in the active grid (or original rank if inactive).
-            col (int) : Grid column index of this process in the active grid (or original rank if inactive).
-            is_active (bool) : Flag indicating whether this rank is in the active sub-grid.
+        comm : :obj:`mpi4py.MPI.Comm`
+            Sub-communicator including only active ranks.
+        rank : :obj:`int`
+            Rank within the new sub-communicator (or original rank
+            if inactive).
+        row : :obj:`int`
+            Grid row index of this process in the active grid (or original rank
+            if inactive).
+        col : :obj:`int`
+            Grid column index of this process in the active grid
+            (or original rank if inactive).
+        is_active : :obj:`bool`
+            Flag indicating whether this rank is in the active sub-grid.
+
         """
         rank = base_comm.Get_rank()
         size = base_comm.Get_size()
@@ -229,10 +240,10 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
                         if (r // p_prime) < active_dim and (r % p_prime) < active_dim]
         new_group = base_comm.Get_group().Incl(active_ranks)
         new_comm = base_comm.Create_group(new_group)
-
         p_prime_new = math.isqrt(len(active_ranks))
         new_rank = new_comm.Get_rank()
         new_row, new_col = divmod(new_rank, p_prime_new)
+
         return new_comm, new_rank, new_row, new_col, True
 
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
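
For context, here is a minimal usage sketch of the method documented above, as it might be called from a driver script. The operator class name MPIMatrixMult and its import path are assumptions for illustration, since the enclosing class is not named in this hunk; only the active_grid_comm signature and its return values come from the docstring.

# Hypothetical driver script: class name and import path are assumptions,
# not confirmed by this diff.
import math

from mpi4py import MPI

from pylops_mpi.basicoperators import MPIMatrixMult  # assumed import path

N, M = 1000, 800  # global data domain: N rows, M columns

# Build the active sub-grid from the world communicator.
comm, rank, row, col, is_active = MPIMatrixMult.active_grid_comm(
    MPI.COMM_WORLD, N, M)

if is_active:
    # Active ranks sit on a P' x P' grid in row-major order; (row, col)
    # identifies this rank's position in that grid.
    p_prime = math.isqrt(comm.Get_size())
    print(f"rank {rank}: position ({row}, {col}) on a "
          f"{p_prime} x {p_prime} grid")
else:
    # Inactive ranks receive no new communicator and skip the computation.
    print(f"rank {rank}: inactive")

Run under MPI (e.g., mpiexec -n 9 python driver.py); ranks outside the active sub-grid return immediately, as described in the docstring.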