@@ -95,28 +95,89 @@ def block_distribute(array:NDArray, rank:int, comm: MPI.Comm, pad:bool=False):
     if pad and (pr or pc): block = np.pad(block, [(0, pr), (0, pc)], mode='constant')
     return block, (new_r, new_c)
 
-def local_block_spit(global_shape: Tuple[int, int], rank: int, comm: MPI.Comm) -> Tuple[slice, slice]:
+def local_block_spit(global_shape: Tuple[int, int],
+                     rank: int,
+                     comm: MPI.Comm) -> Tuple[slice, slice]:
+    """
+    Compute the local sub-block of a 2D global array for a process in a square process grid.
+
+    Parameters
+    ----------
+    global_shape : Tuple[int, int]
+        Dimensions of the global 2D array (n_rows, n_cols).
+    rank : int
+        Rank of the MPI process in `comm` for which to get the owned block partition.
+    comm : MPI.Comm
+        MPI communicator whose total number of processes :math:`\mathbf{P}`
+        must be a perfect square, :math:`\mathbf{P} = \mathbf{P'}^2` with :math:`\mathbf{P'} = \sqrt{\mathbf{P}}`.
+
+    Returns
+    -------
+    Tuple[slice, slice]
+        Two `slice` objects `(row_slice, col_slice)` indicating the sub-block
+        of the global array owned by this rank.
+
+    Raises
+    ------
+    ValueError
+        If `rank` is out of range.
+    RuntimeError
+        If the number of processes participating in the provided communicator is not a perfect square.
+    """
     size = comm.Get_size()
     p_prime = math.isqrt(size)
     if p_prime * p_prime != size:
-        raise Exception(f"Number of processes must be a square number, provided {size} instead...")
+        raise RuntimeError(f"Number of processes must be a square number, provided {size} instead...")
+    if not (isinstance(rank, int) and 0 <= rank < size):
+        raise ValueError(f"rank must be integer in [0, {size}), got {rank!r}")
 
     proc_i, proc_j = divmod(rank, p_prime)
     orig_r, orig_c = global_shape
+
     new_r = math.ceil(orig_r / p_prime) * p_prime
     new_c = math.ceil(orig_c / p_prime) * p_prime
 
-    br, bc = new_r // p_prime, new_c // p_prime
-    i0, j0 = proc_i * br, proc_j * bc
-    i1, j1 = min(i0 + br, orig_r), min(j0 + bc, orig_c)
+    blkr, blkc = new_r // p_prime, new_c // p_prime
 
-    i_end = None if proc_i == p_prime - 1 else i1
-    j_end = None if proc_j == p_prime - 1 else j1
-    return slice(i0, i_end), slice(j0, j_end)
+    i0, j0 = proc_i * blkr, proc_j * blkc
+    i1, j1 = min(i0 + blkr, orig_r), min(j0 + blkc, orig_c)
+
+    return slice(i0, i1), slice(j0, j1)
+
+
+def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tuple[int, int], comm: MPI.Comm):
+    """
+    Gather the local blocks of a 2D block-distributed matrix, distributed
+    amongst a square process grid, into the full global array.
+
+    Parameters
+    ----------
+    x : :obj:`pylops_mpi.DistributedArray`
+        The distributed array to gather locally.
+    new_shape : Tuple[int, int]
+        Shape `(N', M')` of the padded global array, where both dimensions
+        are multiples of :math:`\sqrt{\mathbf{P}}`.
+    orig_shape : Tuple[int, int]
+        Original shape `(N, M)` of the global array before padding.
+    comm : MPI.Comm
+        MPI communicator whose size must be a perfect square (P = p_prime**2).
+
+    Returns
+    -------
+    Array
+        The reconstructed 2D array of shape `orig_shape`, assembled from
+        the distributed blocks.
 
-def block_gather(x, new_shape, orig_shape, comm):
+    Raises
+    ------
+    RuntimeError
+        If the number of processes participating in the provided communicator is not a perfect square.
+    """
     ncp = get_module(x.engine)
     p_prime = math.isqrt(comm.Get_size())
+    if p_prime * p_prime != comm.Get_size():
+        raise RuntimeError(f"Communicator size must be a perfect square, got {comm.Get_size()!r}")
+
     all_blks = comm.allgather(x.local_array)
 
     nr, nc = new_shape
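
For orientation, a minimal usage sketch of `local_block_spit` (assumptions: the helper is importable alongside this module, the script is launched with a square number of ranks, and the 6x5 array is purely illustrative):

# Hedged sketch: each rank asks local_block_spit for the tile of a 6x5
# global array it owns on a square process grid, e.g. `mpiexec -n 4 python demo.py`.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

global_shape = (6, 5)                      # not multiples of sqrt(P); edge tiles are smaller
row_sl, col_sl = local_block_spit(global_shape, rank, comm)

A = np.arange(30, dtype=float).reshape(global_shape)
local_tile = A[row_sl, col_sl]             # the tile this rank owns
print(rank, row_sl, col_sl, local_tile.shape)
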
@@ -151,10 +212,14 @@ def block_gather(x, new_shape, orig_shape, comm):
         block = all_blks[rank]
         if block.ndim == 1:
             block = block.reshape(block_rows, block_cols)
-        C[start_row:start_row + block_rows, start_col:start_col + block_cols] = block
+        C[start_row:start_row + block_rows,
+          start_col:start_col + block_cols] = block
+
+    # Trim off any padding
     return C[:orr, :orc]
 
 
+
 class MPIMatrixMult(MPILinearOperator):
     r"""MPI Matrix multiplication
 
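
The two helpers above are intended to round-trip: `block_distribute` hands each rank its (optionally zero-padded) tile together with the padded global shape, and `block_gather` reassembles the original array. A hedged sketch, assuming both helpers are importable and using a hypothetical stand-in object in place of a real `pylops_mpi.DistributedArray` (the gather code shown only touches `.engine` and `.local_array`):

# Hedged round-trip sketch; _LocalBlock is a made-up stand-in, not a pylops_mpi class.
from dataclasses import dataclass
import numpy as np
from mpi4py import MPI

@dataclass
class _LocalBlock:
    local_array: np.ndarray
    engine: str = "numpy"

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

A = np.arange(30, dtype=float).reshape(6, 5)
block, padded_shape = block_distribute(A, rank, comm, pad=True)
A_back = block_gather(_LocalBlock(block), padded_shape, A.shape, comm)
print(rank, np.allclose(A_back, A))        # expected True if the helpers behave as documented
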
@@ -360,7 +425,7 @@ class MPISummaMatrixMult(MPILinearOperator):
     Implements distributed matrix-matrix multiplication using the SUMMA algorithm
     between a matrix :math:`\mathbf{A}` distributed over a 2D process grid and
     input model and data vectors, which are both interpreted as matrices
-    distributed in block-column fashion.
+    distributed in block fashion, wherein each process owns a tile of the matrix.
 
     Parameters
     ----------
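
The tile ownership described above follows the same row-major rank-to-grid mapping that `local_block_spit` uses earlier in this diff (`divmod(rank, p_prime)`); a minimal sketch of that mapping, assuming a communicator with a square number of ranks:

# Sketch of the rank -> (proc_i, proc_j) tile mapping on a p_prime x p_prime grid.
import math
from mpi4py import MPI

comm = MPI.COMM_WORLD
p_prime = math.isqrt(comm.Get_size())
proc_i, proc_j = divmod(comm.Get_rank(), p_prime)
print(f"rank {comm.Get_rank()} owns tile ({proc_i}, {proc_j}) of the process grid")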