@@ -152,6 +152,64 @@ def __init__(
152152         shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
153153         super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
154154
155+     @staticmethod
156+     def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
157+ r"""Configure active grid
158+
159+ Configure a square process grid from a parent MPI communicator and
160+ select a subset of "active" processes. Each process in ``base_comm``
161+ is assigned to a logical 2D grid of size :math:`P' \times P'`,
162+ where :math:`P' = \bigl \lceil \sqrt{P} \bigr \rceil`. Only the first
163+ :math:`active_dim x active_dim` processes
164+ (by row-major order) are considered "active". Inactive ranks return
165+ immediately with no new communicator.
166+
167+ Parameters:
168+ -----------
169+ base_comm : :obj:`mpi4py.MPI.Comm`
170+ MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``).
171+ N : :obj:`int`
172+ Number of rows of the global data domain.
173+ M : :obj:`int`
174+ Number of columns of the global data domain.
175+
176+ Returns:
177+ --------
178+ comm : :obj:`mpi4py.MPI.Comm`
179+ Sub-communicator including only active ranks.
180+ rank : :obj:`int`
181+ Rank within the new sub-communicator (or original rank
182+ if inactive).
183+ row : :obj:`int`
184+ Grid row index of this process in the active grid (or original rank
185+ if inactive).
186+ col : :obj:`int`
187+ Grid column index of this process in the active grid
188+ (or original rank if inactive).
189+ is_active : :obj:`bool`
190+ Flag indicating whether this rank is in the active sub-grid.
191+
192+ """
193+         rank = base_comm.Get_rank()
194+         size = base_comm.Get_size()
195+         p_prime = math.isqrt(size)
196+         row, col = divmod(rank, p_prime)
197+         active_dim = min(N, M, p_prime)
198+         is_active = (row < active_dim and col < active_dim)
199+
200+         if not is_active:
201+             return None, rank, row, col, False
202+
203+         active_ranks = [r for r in range(size)
204+                         if (r // p_prime) < active_dim and (r % p_prime) < active_dim]
205+         new_group = base_comm.Get_group().Incl(active_ranks)
206+         new_comm = base_comm.Create_group(new_group)
207+         p_prime_new = math.isqrt(len(active_ranks))
208+         new_rank = new_comm.Get_rank()
209+         new_row, new_col = divmod(new_rank, p_prime_new)
210+
211+         return new_comm, new_rank, new_row, new_col, True
212+
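For context, a minimal usage sketch of ``active_grid_comm`` (not part of this commit; it assumes the enclosing operator class is exported as ``pylops_mpi.MPIMatrixMult`` and that the script is launched with ``mpirun``, e.g. on 5 ranks so that one rank falls outside the 2 x 2 active grid):

    from mpi4py import MPI
    from pylops_mpi import MPIMatrixMult  # assumed export; adjust to the actual class/module

    # Build the (at most) P' x P' active grid from COMM_WORLD for a 100 x 80 problem
    comm, rank, row, col, is_active = MPIMatrixMult.active_grid_comm(MPI.COMM_WORLD, N=100, M=80)
    if not is_active:
        # Inactive ranks receive comm=None and simply skip the matrix-multiply setup
        print(f"world rank {MPI.COMM_WORLD.Get_rank()} is idle")
    else:
        print(f"active rank {rank} sits at grid position ({row}, {col}) "
              f"of a {comm.Get_size()}-rank sub-communicator")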
155213     @staticmethod
156214     def block_distribute(array, proc_i, proc_j, comm):
157215         p_prime = math.isqrt(comm.Get_size())
@@ -188,11 +246,12 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
188246         ncp = get_module(x.engine)
189247         if x.partition != Partition.SCATTER:
190248             raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead...")
191-         y = DistributedArray(global_shape=(self.N // self._P_prime, self.M * self._P_prime),
249+         local_shape = (self.N // self._P_prime) * (self.M * self._P_prime // self.size)
250+         y = DistributedArray(global_shape=((self.N // self._P_prime) * self.M * self._P_prime),
192251                              mask=x.mask,
252+                              local_shapes=[local_shape for _ in range(self.size)],
193253                              partition=Partition.SCATTER,
194-                              dtype=self.dtype,
195-                              axis=1)
254+                              dtype=self.dtype)
196255
197256         x = x.local_array.reshape((self.A.shape[1], -1))
198257         c_local = np.zeros((self.A.shape[0], x.shape[1]))
@@ -202,26 +261,47 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
202261             self._row_comm.Bcast(Atemp, root=k)
203262             self._col_comm.Bcast(Xtemp, root=k)
204263             c_local += ncp.dot(Atemp, Xtemp)
205-         y[:] = c_local
264+         y[:] = c_local.flatten()
206265         return y
207266
267+
208268     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
209269         ncp = get_module(x.engine)
210270         if x.partition != Partition.SCATTER:
211271             raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
212-         return None
213-         # y = DistributedArray(
214-         #     global_shape=(self.K * self.dimsd[1]),
215-         #     local_shapes=[self.K * c for c in self._rank_col_lens],
216-         #     mask=x.mask,
217-         #     partition=Partition.SCATTER,
218-         #     dtype=self.dtype,
219-         # )
220-         #
221-         # x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
222-         # X_tile = x_arr[self._row_start:self._row_end, :]
223-         # A_local = self.At if hasattr(self, "At") else self.A.T.conj()
224-         # Y_local = ncp.matmul(A_local, X_tile)
225-         # y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM)
226-         # y[:] = y_layer.flatten()
227-         # return y
272+
273+         local_shape = (self.K // self._P_prime) * (self.M * self._P_prime // self.size)
274+         y = DistributedArray(
275+             global_shape=((self.K // self._P_prime) * self.M * self._P_prime),
276+             mask=x.mask,
277+             local_shapes=[local_shape for _ in range(self.size)],
278+             partition=Partition.SCATTER,
279+             dtype=self.dtype,
280+         )
281+         x_reshaped = x.local_array.reshape((self.A.shape[0], -1))
282+         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
283+         c_local = np.zeros((self.A.shape[1], x_reshaped.shape[1]))
284+         P = self._P_prime
285+
286+         for k in range(P):
287+             temps = {}
288+             requests = []
289+             for buf, owner, base, name in (
290+                 (A_local, self._row_id, 100, 'A'),
291+                 (x_reshaped, self._col_id, 200, 'B'),
292+             ):
293+                 tmp = np.empty_like(buf)
294+                 temps[name] = tmp
295+                 src, tag = k * P + owner, (base + k) * 1000 + self.rank
296+                 requests.append(self.base_comm.Irecv(tmp, source=src, tag=tag))
297+
298+                 if self.rank // P == k:
299+                     fixed = self.rank % P
300+                     for moving in range(P):
301+                         dest = (fixed * P + moving) if name == 'A' else moving * P + fixed
302+                         tag = (base + k) * 1000 + dest
303+                         requests.append(self.base_comm.Isend(buf, dest=dest, tag=tag))
304+             MPI.Request.Waitall(requests)
305+             c_local += ncp.dot(temps['A'], temps['B'])
306+         y[:] = c_local.flatten()
307+         return y