
Commit 9aedd7c

MatrixMult works with a non-square number of processes by creating a square subcommunicator
1 parent 053e52d commit 9aedd7c
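
In brief: when the number of MPI ranks is not a perfect square, the operator now carves out the largest square process grid that fits and lets the remaining ranks sit idle. The snippet below is a minimal standalone sketch of that selection, mirroring the active_grid_comm helper added in this commit (the real method additionally caps the grid at min(N, M)); it is illustrative only, not the library code.

import math
from mpi4py import MPI

base_comm = MPI.COMM_WORLD
rank, size = base_comm.Get_rank(), base_comm.Get_size()

p_prime = math.isqrt(size)        # side of the largest square grid that fits
row, col = divmod(rank, p_prime)  # row-major position on that grid
is_active = row < p_prime and col < p_prime

if is_active:
    # Only the ranks inside the p_prime x p_prime grid build the sub-communicator;
    # MPI.Comm.Create_group is collective over the group members only.
    active_ranks = list(range(p_prime * p_prime))
    group = base_comm.Get_group().Incl(active_ranks)
    square_comm = base_comm.Create_group(group)
    print(f"rank {rank} -> grid ({row}, {col}) of {p_prime} x {p_prime}")
else:
    print(f"rank {rank} is idle (outside the square grid)")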

File tree: 3 files changed (+99, -62 lines)

3 files changed

+99
-62
lines changed

examples/plot_matrixmult.py

Lines changed: 23 additions & 30 deletions
@@ -35,20 +35,6 @@
 # filled with data that is appropriate for the use-case.
 np.random.seed(42)
 
-###############################################################################
-# Next we obtain the MPI parameters for each rank and check that the number
-# of processes (``size``) is a square number
-comm = MPI.COMM_WORLD
-rank = comm.Get_rank()  # rank of current process
-size = comm.Get_size()  # number of processes
-
-p_prime = math.isqrt(size)
-repl_factor = p_prime
-
-if (p_prime * repl_factor) != size:
-    print(f"Number of processes must be a square number, provided {size} instead...")
-    exit(-1)
-
 ###############################################################################
 # We are now ready to create the input matrices :math:`\mathbf{A}` of size
 # :math:`N \times K` and :math:`\mathbf{X}` of size :math:`K \times M`
@@ -93,12 +79,15 @@
 # └────────────┴────────────┘
 # </div>
 
-my_col = rank % p_prime
-my_row = rank // p_prime
+base_comm = MPI.COMM_WORLD
+comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
+print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
+if not is_active: exit(0)
+p_prime = math.isqrt(comm.Get_size())
 
 # Create sub-communicators
-row_comm = comm.Split(color=my_row, key=my_col)  # all procs in same row
-col_comm = comm.Split(color=my_col, key=my_row)  # all procs in same col
+row_comm = comm.Split(color=row_id, key=col_id)  # all procs in same row
+col_comm = comm.Split(color=col_id, key=row_id)  # all procs in same col
 
 ################################################################################
 # At this point we divide the rows and columns of :math:`\mathbf{A}` and
@@ -136,20 +125,20 @@
 blk_rows = int(math.ceil(N / p_prime))
 blk_cols = int(math.ceil(M / p_prime))
 
-rs = my_col * blk_rows
+rs = col_id * blk_rows
 re = min(N, rs + blk_rows)
-my_own_rows = re - rs
+my_own_rows = max(0, re - rs)
 
-cs = my_row * blk_cols
+cs = row_id * blk_cols
 ce = min(M, cs + blk_cols)
-my_own_cols = ce - cs
+my_own_cols = max(0, ce - cs)
 
 A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()
 
 ################################################################################
 # We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
 # operator and the input matrix :math:`\mathbf{X}`
-Aop = MPIMatrixMult(A_p, M, dtype="float32")
+Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")
 
 col_lens = comm.allgather(my_own_cols)
 total_cols = np.sum(col_lens)
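
Why the new max(0, ...) guards (and the empty-block checks in the hunks below) are needed: when the square grid has more rows or columns than the matrix, some ranks own zero rows or columns, and the naive end-minus-start difference can even go negative. A small sketch with hypothetical sizes N=6, M=5 on a 4 x 4 grid, evaluating the same block formulas as above:

import math

N, M, p_prime = 6, 5, 4                  # hypothetical sizes, for illustration only

blk_rows = int(math.ceil(N / p_prime))   # 2
blk_cols = int(math.ceil(M / p_prime))   # 2

for row_id in range(p_prime):
    for col_id in range(p_prime):
        rs, cs = col_id * blk_rows, row_id * blk_cols
        re, ce = min(N, rs + blk_rows), min(M, cs + blk_cols)
        # Without max(0, ...), row_id == 3 would report -1 columns (cs=6 > M=5).
        my_own_rows = max(0, re - rs)
        my_own_cols = max(0, ce - cs)
        print(f"grid ({row_id}, {col_id}): rows {rs}:{re} -> {my_own_rows}, "
              f"cols {cs}:{ce} -> {my_own_cols}")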
@@ -188,9 +177,11 @@
 offset = 0
 for cnt in col_counts:
     block_size = N * cnt
-    y_blocks.append(
-        y[offset: offset + block_size].reshape(N, cnt)
-    )
+    y_block = y[offset: offset + block_size]
+    if len(y_block) != 0:
+        y_blocks.append(
+            y_block.reshape(N, cnt)
+        )
     offset += block_size
 y = np.hstack(y_blocks)
 
@@ -199,13 +190,15 @@
 offset = 0
 for cnt in col_counts:
     block_size = K * cnt
-    xadj_blocks.append(
-        xadj[offset: offset + block_size].reshape(K, cnt)
-    )
+    xadj_blk = xadj[offset: offset + block_size]
+    if len(xadj_blk) != 0:
+        xadj_blocks.append(
+            xadj_blk.reshape(K, cnt)
+        )
     offset += block_size
 xadj = np.hstack(xadj_blocks)
 
-if rank == 0:
+if comm.Get_rank() == 0:
     y_loc = (A @ X).squeeze()
     xadj_loc = (A.T.dot(y_loc.conj())).conj().squeeze()
 
pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 57 additions & 6 deletions
@@ -154,7 +154,7 @@ def __init__(
         self._col_start = self._row_id * block_cols
         self._col_end = min(self.M, self._col_start + block_cols)
 
-        self._local_ncols = self._col_end - self._col_start
+        self._local_ncols = max(0, self._col_end - self._col_start)
         self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
         total_ncols = np.sum(self._rank_col_lens)
 
@@ -168,11 +168,14 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
 
-        y = DistributedArray(global_shape=(self.N * self.dimsd[1]),
-                             local_shapes=[(self.N * c) for c in self._rank_col_lens],
-                             mask=x.mask,
-                             partition=Partition.SCATTER,
-                             dtype=self.dtype)
+        y = DistributedArray(
+            global_shape=(self.N * self.dimsd[1]),
+            local_shapes=[(self.N * c) for c in self._rank_col_lens],
+            mask=x.mask,
+            partition=Partition.SCATTER,
+            dtype=self.dtype,
+            base_comm=self.base_comm
+        )
 
         my_own_cols = self._rank_col_lens[self.rank]
         x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
@@ -185,6 +188,53 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         y[:] = Y_local.flatten()
         return y
 
+    @staticmethod
+    def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
+        """
+        Configure a square process grid from a parent MPI communicator and select the subset of "active" processes.
+
+        Each process in base_comm is assigned to a logical 2D grid of size p_prime x p_prime,
+        where p_prime = floor(sqrt(total_ranks)). Only the first `active_dim x active_dim` processes
+        (by row-major order) are considered "active". Inactive ranks return immediately with no new communicator.
+
+        Parameters
+        ----------
+        base_comm : MPI.Comm
+            The parent communicator (e.g., MPI.COMM_WORLD).
+        N : int
+            Number of rows of the global data domain.
+        M : int
+            Number of columns of the global data domain.
+
+        Returns
+        -------
+        tuple:
+            comm (MPI.Comm or None) : Sub-communicator including only active ranks.
+            rank (int) : Rank within the new sub-communicator (or original rank if inactive).
+            row (int) : Grid row index of this process in the active grid (or original rank if inactive).
+            col (int) : Grid column index of this process in the active grid (or original rank if inactive).
+            is_active (bool) : Flag indicating whether this rank is in the active sub-grid.
+        """
+        rank = base_comm.Get_rank()
+        size = base_comm.Get_size()
+        p_prime = math.isqrt(size)
+        row, col = divmod(rank, p_prime)
+        active_dim = min(N, M, p_prime)
+        is_active = (row < active_dim and col < active_dim)
+
+        if not is_active:
+            return None, rank, row, col, False
+
+        active_ranks = [r for r in range(size)
+                        if (r // p_prime) < active_dim and (r % p_prime) < active_dim]
+        new_group = base_comm.Get_group().Incl(active_ranks)
+        new_comm = base_comm.Create_group(new_group)
+
+        p_prime_new = math.isqrt(len(active_ranks))
+        new_rank = new_comm.Get_rank()
+        new_row, new_col = divmod(new_rank, p_prime_new)
+        return new_comm, new_rank, new_row, new_col, True
+
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
@@ -196,6 +246,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
             mask=x.mask,
             partition=Partition.SCATTER,
             dtype=self.dtype,
+            base_comm=self.base_comm
         )
 
         x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
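
For context, the example and test updated in this commit drive the new static method roughly as follows. This is a trimmed sketch only: the global sizes N and M are placeholder values here, and the real example goes on to build MPIMatrixMult(A_p, M, base_comm=comm, ...) on the active ranks.

import math
import sys

from mpi4py import MPI
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

N, M = 8, 5                      # hypothetical global matrix dimensions
base_comm = MPI.COMM_WORLD

# Select the square sub-grid; inactive ranks simply stop here.
comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
if not is_active:
    sys.exit(0)

p_prime = math.isqrt(comm.Get_size())
# Row/column communicators over the active square grid only.
row_comm = comm.Split(color=row_id, key=col_id)  # all procs in same row
col_comm = comm.Split(color=col_id, key=row_id)  # all procs in same col
print(f"rank {rank}: grid position ({row_id}, {col_id}) on a {p_prime} x {p_prime} grid")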

tests/test_matrixmult.py

Lines changed: 19 additions & 26 deletions
@@ -12,9 +12,7 @@
 from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
 
 np.random.seed(42)
-comm = MPI.COMM_WORLD
-rank = comm.Get_rank()
-size = comm.Get_size()
+base_comm = MPI.COMM_WORLD
 
 # Define test cases: (N, K, M, dtype_str)
 # M, K, N are matrix dimensions A(N,K), B(K,M)
@@ -32,31 +30,25 @@
 @pytest.mark.mpi(min_size=1)
 @pytest.mark.parametrize("M, K, N, dtype_str", test_params)
 def test_SUMMAMatrixMult(N, K, M, dtype_str):
-    p_prime = math.isqrt(size)
-    C = p_prime
-    if p_prime * C != size:
-        pytest.skip("Number of processes must be a square number, "
-                    "provided {size} instead...")
-
     dtype = np.dtype(dtype_str)
 
     cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
     base_float_dtype = np.float32 if dtype == np.complex64 else np.float64
 
-    my_col = rank % p_prime
-    my_row = rank // p_prime
+    comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
+    print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
+    if not is_active: return
 
-    # Create sub-communicators
-    row_comm = comm.Split(color=my_row, key=my_col)
-    col_comm = comm.Split(color=my_col, key=my_row)
+    size = comm.Get_size()
+    p_prime = math.isqrt(size)
 
     # Calculate local matrix dimensions
     blk_rows_A = int(math.ceil(N / p_prime))
-    row_start_A = my_col * blk_rows_A
+    row_start_A = col_id * blk_rows_A
     row_end_A = min(N, row_start_A + blk_rows_A)
 
     blk_cols_X = int(math.ceil(M / p_prime))
-    col_start_X = my_row * blk_cols_X
+    col_start_X = row_id * blk_cols_X
     col_end_X = min(M, col_start_X + blk_cols_X)
     local_col_X_len = max(0, col_end_X - col_start_X)
 
@@ -102,9 +94,11 @@ def test_SUMMAMatrixMult(N, K, M, dtype_str):
     offset = 0
     for cnt in col_counts:
         block_size = N * cnt
-        y_blocks.append(
-            y[offset: offset + block_size].reshape(N, cnt)
-        )
+        y_block = y[offset: offset + block_size]
+        if len(y_block) != 0:
+            y_blocks.append(
+                y_block.reshape(N, cnt)
+            )
         offset += block_size
     y = np.hstack(y_blocks)
 
@@ -113,9 +107,11 @@ def test_SUMMAMatrixMult(N, K, M, dtype_str):
     offset = 0
     for cnt in col_counts:
         block_size = K * cnt
-        xadj_blocks.append(
-            xadj[offset: offset + block_size].reshape(K, cnt)
-        )
+        xadj_blk = xadj[offset: offset + block_size]
+        if len(xadj_blk) != 0:
+            xadj_blocks.append(
+                xadj_blk.reshape(K, cnt)
+            )
         offset += block_size
     xadj = np.hstack(xadj_blocks)
 
@@ -134,7 +130,4 @@ def test_SUMMAMatrixMult(N, K, M, dtype_str):
         xadj_loc.squeeze(),
         rtol=np.finfo(np.dtype(dtype)).resolution,
         err_msg=f"Rank {rank}: Adjoint verification failed."
-    )
-
-    col_comm.Free()
-    row_comm.Free()
+    )
