
Commit 56e9414

Added docstring
1 parent 1ef09ab commit 56e9414

File tree

1 file changed: +76 -11 lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 76 additions & 11 deletions
@@ -95,28 +95,89 @@ def block_distribute(array:NDArray, rank:int, comm: MPI.Comm, pad:bool=False):
     if pad and (pr or pc): block = np.pad(block, [(0, pr), (0, pc)], mode='constant')
     return block, (new_r, new_c)
 
-def local_block_spit(global_shape: Tuple[int, int], rank: int, comm: MPI.Comm) -> Tuple[slice, slice]:
+def local_block_spit(global_shape: Tuple[int, int],
+                     rank: int,
+                     comm: MPI.Comm) -> Tuple[slice, slice]:
+    """
+    Compute the local sub-block of a 2D global array for a process in a square process grid.
+
+    Parameters
+    ----------
+    global_shape : Tuple[int, int]
+        Dimensions of the global 2D array (n_rows, n_cols).
+    rank : int
+        Rank of the MPI process in `comm` for which to get the owned block partition.
+    comm : MPI.Comm
+        MPI communicator whose total number of processes :math:`\mathbf{P}`
+        must be a perfect square, :math:`\mathbf{P} = (\mathbf{P'})^2` with
+        :math:`\mathbf{P'} = \sqrt{\mathbf{P}}` an integer.
+
+    Returns
+    -------
+    Tuple[slice, slice]
+        Two `slice` objects `(row_slice, col_slice)` indicating the sub-block
+        of the global array owned by this rank.
+
+    Raises
+    ------
+    ValueError
+        If `rank` is out of range.
+    RuntimeError
+        If the number of processes participating in the provided communicator is not a perfect square.
+    """
     size = comm.Get_size()
     p_prime = math.isqrt(size)
     if p_prime * p_prime != size:
-        raise Exception(f"Number of processes must be a square number, provided {size} instead...")
+        raise RuntimeError(f"Number of processes must be a square number, provided {size} instead...")
+    if not (isinstance(rank, int) and 0 <= rank < size):
+        raise ValueError(f"rank must be an integer in [0, {size}), got {rank!r}")
 
     proc_i, proc_j = divmod(rank, p_prime)
     orig_r, orig_c = global_shape
+
     new_r = math.ceil(orig_r / p_prime) * p_prime
     new_c = math.ceil(orig_c / p_prime) * p_prime
 
-    br, bc = new_r // p_prime, new_c // p_prime
-    i0, j0 = proc_i * br, proc_j * bc
-    i1, j1 = min(i0 + br, orig_r), min(j0 + bc, orig_c)
+    blkr, blkc = new_r // p_prime, new_c // p_prime
 
-    i_end = None if proc_i == p_prime - 1 else i1
-    j_end = None if proc_j == p_prime - 1 else j1
-    return slice(i0, i_end), slice(j0, j_end)
+    i0, j0 = proc_i * blkr, proc_j * blkc
+    i1, j1 = min(i0 + blkr, orig_r), min(j0 + blkc, orig_c)
+
+    return slice(i0, i1), slice(j0, j1)
+
+
+def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tuple[int, int], comm: MPI.Comm):
+    """
+    Gather the local blocks of a 2D block-distributed matrix, spread over a
+    square process grid, into the full global array.
+
+    Parameters
+    ----------
+    x : :obj:`pylops_mpi.DistributedArray`
+        The distributed array to gather locally.
+    new_shape : Tuple[int, int]
+        Shape `(N', M')` of the padded global array, where both dimensions
+        are multiples of :math:`\sqrt{\mathbf{P}}`.
+    orig_shape : Tuple[int, int]
+        Original shape `(N, M)` of the global array before padding.
+    comm : MPI.Comm
+        MPI communicator whose size must be a perfect square (P = p_prime**2).
+
+    Returns
+    -------
+    Array
+        The reconstructed 2D array of shape `orig_shape`, assembled from
+        the distributed blocks.
 
-def block_gather(x, new_shape, orig_shape, comm):
+    Raises
+    ------
+    RuntimeError
+        If the number of processes participating in the provided communicator is not a perfect square.
+    """
     ncp = get_module(x.engine)
     p_prime = math.isqrt(comm.Get_size())
+    if p_prime * p_prime != comm.Get_size():
+        raise RuntimeError(f"Communicator size must be a perfect square, got {comm.Get_size()!r}")
+
     all_blks = comm.allgather(x.local_array)
 
     nr, nc = new_shape
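
As a quick illustration, a minimal usage sketch of local_block_spit (assumptions: it is importable from pylops_mpi.basicoperators.MatrixMult, and the script is launched under MPI with a square process count, e.g. mpiexec -n 4). Each rank prints the row/column slices it owns for a 5 x 6 global array on the resulting 2 x 2 process grid; the last process row only receives 2 of the 3 nominal rows since the padded shape is 6 x 6.

# usage sketch (assumes `mpiexec -n 4 python example.py` and that
# local_block_spit is importable from pylops_mpi.basicoperators.MatrixMult)
import numpy as np
from mpi4py import MPI
from pylops_mpi.basicoperators.MatrixMult import local_block_spit

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# 5 x 6 global matrix; with P = 4 the grid is 2 x 2, the padded shape is 6 x 6,
# so nominal blocks are 3 x 3 and edge blocks are clipped to the true extent
A = np.arange(5 * 6).reshape(5, 6)
row_slice, col_slice = local_block_spit(A.shape, rank, comm)
local_block = A[row_slice, col_slice]
print(f"rank {rank}: rows {row_slice}, cols {col_slice}, block shape {local_block.shape}")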
@@ -151,10 +212,14 @@ def block_gather(x, new_shape, orig_shape, comm):
         block = all_blks[rank]
         if block.ndim == 1:
             block = block.reshape(block_rows, block_cols)
-        C[start_row:start_row + block_rows, start_col:start_col + block_cols] = block
+        C[start_row:start_row + block_rows,
+          start_col:start_col + block_cols] = block
+
+    # Trim off any padding
     return C[:orr, :orc]
 
 
+
 class MPIMatrixMult(MPILinearOperator):
     r"""MPI Matrix multiplication
 
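
The assemble-and-trim step above can be illustrated without MPI. The following standalone sketch (hypothetical code, not the library call; the real block_gather obtains its blocks via comm.allgather on a DistributedArray) splits a small matrix into padded tiles, places them back into the padded array C, and trims the result to the original shape, mirroring the placement loop and the final C[:orr, :orc].

# standalone sketch of the assemble-and-trim logic used by block_gather
import math
import numpy as np

orig_shape = (5, 6)          # original (N, M)
p_prime = 2                  # sqrt(P) for P = 4 processes
new_shape = tuple(math.ceil(s / p_prime) * p_prime for s in orig_shape)  # padded (6, 6)

# fake "gathered" blocks: rank r owns tile (r // p_prime, r % p_prime) of the padded array
A = np.arange(orig_shape[0] * orig_shape[1]).reshape(orig_shape)
A_pad = np.pad(A, [(0, new_shape[0] - orig_shape[0]), (0, new_shape[1] - orig_shape[1])], mode='constant')
blkr, blkc = new_shape[0] // p_prime, new_shape[1] // p_prime
all_blks = [A_pad[i * blkr:(i + 1) * blkr, j * blkc:(j + 1) * blkc]
            for i in range(p_prime) for j in range(p_prime)]

# assemble the blocks into the padded array, then trim off any padding
C = np.zeros(new_shape, dtype=A.dtype)
for rank, block in enumerate(all_blks):
    i, j = divmod(rank, p_prime)
    C[i * blkr:(i + 1) * blkr, j * blkc:(j + 1) * blkc] = block
assert np.array_equal(C[:orig_shape[0], :orig_shape[1]], A)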
@@ -360,7 +425,7 @@ class MPISummaMatrixMult(MPILinearOperator):
     Implements distributed matrix-matrix multiplication using the SUMMA algorithm
     between a matrix :math:`\mathbf{A}` distributed over a 2D process grid and
     input model and data vectors, which are both interpreted as matrices
-    distributed in block-column fashion.
+    distributed in block fashion, wherein each process owns a tile of the matrix.
 
     Parameters
     ----------