Commit 4c662d6

minor: stylistic fixes
1 parent 9aedd7c

File tree

3 files changed (+43, -26 lines):
  examples/plot_matrixmult.py
  pylops_mpi/basicoperators/MatrixMult.py
  tests/test_matrixmult.py

examples/plot_matrixmult.py

Lines changed: 12 additions & 6 deletions
@@ -44,7 +44,7 @@
 X = np.random.rand(K * M).astype(dtype=np.float32).reshape(K, M)
 
 ################################################################################
-# The processes are now arranged in a :math:`\sqrt{P} \times \sqrt{P}` grid,
+# The processes are now arranged in a :math:`P' \times P'` grid,
 # where :math:`P` is the total number of processes.
 #
 # We define
@@ -78,14 +78,20 @@
 # │ (r=1, c=0) │ (r=1, c=1) │
 # └────────────┴────────────┘
 # </div>
+#
+# This is obtained by invoking the
+# :func:`pylops_mpi.MPIMatrixMult.active_grid_comm` method, which is also
+# responsible for identifying any rank that should be deactivated (when the
+# number of rows of the operator or columns of the input/output matrices is
+# smaller than the row or column ranks).
 
 base_comm = MPI.COMM_WORLD
 comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
 print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
 if not is_active: exit(0)
-p_prime = math.isqrt(comm.Get_size())
 
 # Create sub-communicators
+p_prime = math.isqrt(comm.Get_size())
 row_comm = comm.Split(color=row_id, key=col_id)  # all procs in same row
 col_comm = comm.Split(color=col_id, key=row_id)  # all procs in same col

@@ -127,11 +133,11 @@
 
 rs = col_id * blk_rows
 re = min(N, rs + blk_rows)
-my_own_rows = max(0,re - rs)
+my_own_rows = max(0, re - rs)
 
 cs = row_id * blk_cols
 ce = min(M, cs + blk_cols)
-my_own_cols = max(0,ce - cs)
+my_own_cols = max(0, ce - cs)
 
 A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()

@@ -191,14 +197,14 @@
 for cnt in col_counts:
     block_size = K * cnt
     xadj_blk = xadj[offset: offset + block_size]
-    if len(xadj_blk)!= 0:
+    if len(xadj_blk) != 0:
         xadj_blocks.append(
             xadj_blk.reshape(K, cnt)
         )
     offset += block_size
 xadj = np.hstack(xadj_blocks)
 
-if comm.Get_rank() == 0:
+if rank == 0:
     y_loc = (A @ X).squeeze()
     xadj_loc = (A.T.dot(y_loc.conj())).conj().squeeze()
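
The two whitespace fixes above touch the example's block-partitioning arithmetic, which is easy to misread: min clamps the last block to the matrix edge, while max(0, ...) zeroes out blocks on ranks whose start offset already lies past it. Below is a minimal, MPI-free sketch of that logic; the ceil-division used for blk_rows/blk_cols and the concrete sizes are assumptions for illustration, not taken from this diff.

import math

# Illustrative sizes (assumed): a 3 x 3 process grid over a 4 x 5 matrix.
p_prime, N, M = 3, 4, 5
blk_rows = math.ceil(N / p_prime)  # 2 rows per block (assumed ceil split)
blk_cols = math.ceil(M / p_prime)  # 2 cols per block

for row_id in range(p_prime):
    for col_id in range(p_prime):
        rs = col_id * blk_rows         # this rank's first row
        re = min(N, rs + blk_rows)     # clamp the last block to N
        my_own_rows = max(0, re - rs)  # 0 when rs >= N (rank holds no rows)
        cs = row_id * blk_cols
        ce = min(M, cs + blk_cols)
        my_own_cols = max(0, ce - cs)
        print(f"(r={row_id}, c={col_id}): {my_own_rows} x {my_own_cols} block")

With N = 4 the third row-block starts at rs = 4 = N, so its size clamps to zero; for N = 1 the raw difference re - rs is even negative, which is exactly what max(0, ...) guards against.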

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 29 additions & 18 deletions
@@ -190,30 +190,41 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
 
     @staticmethod
     def active_grid_comm(base_comm:MPI.Comm, N:int, M:int):
-        """
-        Configure a square process grid from a parent MPI communicator and select the subset of "active" processes.
+        r"""Configure active grid
 
-        Each process in base_comm is assigned to a logical 2D grid of size p_prime x p_prime,
-        where p_prime = floor(sqrt(total_ranks)). Only the first `active_dim x active_dim` processes
-        (by row-major order) are considered "active". Inactive ranks return immediately with no new communicator.
+        Configure a square process grid from a parent MPI communicator and
+        select the subset of "active" processes. Each process in ``base_comm``
+        is assigned to a logical 2D grid of size :math:`P' \times P'`,
+        where :math:`P' = \bigl\lfloor \sqrt{P} \bigr\rfloor`. Only the first
+        :math:`active\_dim \times active\_dim` processes
+        (by row-major order) are considered "active". Inactive ranks return
+        immediately with no new communicator.
 
         Parameters:
         -----------
-        base_comm : MPI.Comm
-            The parent communicator (e.g., MPI.COMM_WORLD).
-        N : int
-            Number of rows of your global data domain.
-        M : int
-            Number of columns of your global data domain.
+        base_comm : :obj:`mpi4py.MPI.Comm`
+            MPI parent communicator (e.g., ``mpi4py.MPI.COMM_WORLD``).
+        N : :obj:`int`
+            Number of rows of the global data domain.
+        M : :obj:`int`
+            Number of columns of the global data domain.
 
         Returns:
         --------
-        tuple:
-            comm (MPI.Comm or None) : Sub-communicator including only active ranks.
-            rank (int) : Rank within the new sub-communicator (or original rank if inactive).
-            row (int) : Grid row index of this process in the active grid (or original rank if inactive).
-            col (int) : Grid column index of this process in the active grid (or original rank if inactive).
-            is_active (bool) : Flag indicating whether this rank is in the active sub-grid.
+        comm : :obj:`mpi4py.MPI.Comm`
+            Sub-communicator including only active ranks.
+        rank : :obj:`int`
+            Rank within the new sub-communicator (or original rank
+            if inactive).
+        row : :obj:`int`
+            Grid row index of this process in the active grid (or original rank
+            if inactive).
+        col : :obj:`int`
+            Grid column index of this process in the active grid
+            (or original rank if inactive).
+        is_active : :obj:`bool`
+            Flag indicating whether this rank is in the active sub-grid.
+
         """
         rank = base_comm.Get_rank()
         size = base_comm.Get_size()
@@ -229,10 +240,10 @@ def active_grid_comm(base_comm:MPI.Comm, N:int, M:int):
                         if (r // p_prime) < active_dim and (r % p_prime) < active_dim]
         new_group = base_comm.Get_group().Incl(active_ranks)
         new_comm = base_comm.Create_group(new_group)
-
         p_prime_new = math.isqrt(len(active_ranks))
         new_rank = new_comm.Get_rank()
         new_row, new_col = divmod(new_rank, p_prime_new)
+
         return new_comm, new_rank, new_row, new_col, True
 
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
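
For reference, a minimal usage sketch of the five-tuple documented above, mirroring the updated example script (the package-level import and the concrete N, M values are assumptions for illustration):

import math
from mpi4py import MPI
from pylops_mpi import MPIMatrixMult  # assumed package-level export

N, M = 100, 50  # assumed global row / column counts

base_comm = MPI.COMM_WORLD
comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
if not is_active:
    exit(0)  # inactive ranks keep their original rank and get no sub-communicator

# Active ranks form a p_prime x p_prime grid, p_prime = floor(sqrt(#active)),
# and split into per-row and per-column communicators as in the example script.
p_prime = math.isqrt(comm.Get_size())
row_comm = comm.Split(color=row_id, key=col_id)  # all procs in same grid row
col_comm = comm.Split(color=col_id, key=row_id)  # all procs in same grid col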

tests/test_matrixmult.py

Lines changed: 2 additions & 2 deletions
@@ -13,6 +13,7 @@
 
 np.random.seed(42)
 base_comm = MPI.COMM_WORLD
+size = base_comm.Get_size()
 
 # Define test cases: (N, K, M, dtype_str)
 # M, K, N are matrix dimensions A(N,K), B(K,M)
@@ -29,14 +30,13 @@
 
 @pytest.mark.mpi(min_size=1)
 @pytest.mark.parametrize("M, K, N, dtype_str", test_params)
-def test_SUMMAMatrixMult(N, K, M, dtype_str):
+def test_MPIMatrixMult(N, K, M, dtype_str):
     dtype = np.dtype(dtype_str)
 
     cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
     base_float_dtype = np.float32 if dtype == np.complex64 else np.float64
 
     comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
-    print(f"Process {base_comm.Get_rank()} is {"active" if is_active else "inactive"}")
     if not is_active: return
 
     size = comm.Get_size()
