Commit 6066ec3

Merge pull request #136 from astroC86/astroC86-SUMMA
feat: MPIMatrixMult with blocking over rows (of the operator) and column (of the input)
2 parents da326a0 + 0c34b78 commit 6066ec3

File tree

5 files changed: 629 additions, 0 deletions

docs/source/api/index.rst

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ Basic Operators
 .. autosummary::
    :toctree: generated/
 
+   MPIMatrixMult
    MPIBlockDiag
    MPIStackedBlockDiag
    MPIVStack

examples/plot_matrixmult.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
r"""
Distributed Matrix Multiplication
=================================
This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}`
blocked over rows (i.e., blocks of rows are stored over different ranks) and a
matrix :math:`\mathbf{X}` blocked over columns (i.e., blocks of columns are
stored over different ranks), with an equal number of row and column blocks.
Similarly, the adjoint operation can be performed with a matrix :math:`\mathbf{Y}`
blocked in the same fashion as matrix :math:`\mathbf{X}`.

Note that whilst the different blocks of the matrix :math:`\mathbf{A}` are directly
stored in the operator on different ranks, the matrix :math:`\mathbf{X}` is
effectively represented by a 1-D :py:class:`pylops_mpi.DistributedArray` where
the different blocks are flattened and stored on different ranks. Note that to
optimize communications, the ranks are organized in a 2D grid and some of the
row blocks of :math:`\mathbf{A}` and column blocks of :math:`\mathbf{X}` are
replicated across different ranks - see below for details.

"""
from matplotlib import pyplot as plt
import math
import numpy as np
from mpi4py import MPI

from pylops_mpi import DistributedArray, Partition
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

plt.close("all")

###############################################################################
# We set the seed such that all processes create the input matrices filled
# with the same random numbers. In practical applications, such matrices will
# be filled with data that is appropriate for the use-case.
np.random.seed(42)

###############################################################################
# We are now ready to create the input matrices: :math:`\mathbf{A}` of size
# :math:`N \times K` and :math:`\mathbf{X}` of size :math:`K \times M`.
N, K, M = 4, 4, 4
A = np.random.rand(N * K).astype(dtype=np.float32).reshape(N, K)
X = np.random.rand(K * M).astype(dtype=np.float32).reshape(K, M)

################################################################################
# The processes are now arranged in a :math:`P' \times P'` grid,
# where :math:`P` is the total number of processes.
#
# We define
#
# .. math::
#    P' = \bigl\lceil \sqrt{P} \bigr\rceil
#
# and the replication factor
#
# .. math::
#    R = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
#
# Each process is therefore assigned a pair of coordinates
# :math:`(r,c)` within this grid:
#
# .. math::
#    r = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor,
#    \quad
#    c = \mathrm{rank} \bmod P'.
#
# For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
#
# .. raw:: html
#
#    <div style="text-align: center; font-family: monospace; white-space: pre;">
#    ┌────────────┬────────────┐
#    │   Rank 0   │   Rank 1   │
#    │ (r=0, c=0) │ (r=0, c=1) │
#    ├────────────┼────────────┤
#    │   Rank 2   │   Rank 3   │
#    │ (r=1, c=0) │ (r=1, c=1) │
#    └────────────┴────────────┘
#    </div>
#
# This grid is obtained by invoking the
# :func:`pylops_mpi.MPIMatrixMult.active_grid_comm` method, which is also
# responsible for identifying any rank that should be deactivated (i.e., when
# the number of rows of the operator or columns of the input/output matrices
# is smaller than the number of row or column ranks).

base_comm = MPI.COMM_WORLD
comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
if not is_active: exit(0)

# Create sub-communicators
p_prime = math.isqrt(comm.Get_size())
row_comm = comm.Split(color=row_id, key=col_id)  # all procs in same row
col_comm = comm.Split(color=col_id, key=row_id)  # all procs in same col

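###############################################################################
# For illustration only, the quantities defined by the formulas above can be
# reproduced with a few lines of plain Python (a minimal sketch of the
# arithmetic, not how the operator derives them internally; the variable names
# below are purely illustrative):
P_total = comm.Get_size()
P_prime_check = math.ceil(math.sqrt(P_total))   # P' = ceil(sqrt(P))
R_check = math.ceil(P_total / P_prime_check)    # replication factor R
r_check = rank // P_prime_check                 # row coordinate r
c_check = rank % P_prime_check                  # column coordinate c
print(f"Rank {rank}: P'={P_prime_check}, R={R_check}, (r, c)=({r_check}, {c_check})")
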
################################################################################
# At this point we divide the rows and columns of :math:`\mathbf{A}` and
# :math:`\mathbf{X}`, respectively, such that each rank ends up with:
#
# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}`
# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}`
#
# .. raw:: html
#
#    <div style="text-align: left; font-family: monospace; white-space: pre;">
#    <b>Matrix A (4 x 4):</b>
#    ┌─────────────────┐
#    │ a11 a12 a13 a14 │ <- Rows 0–1 (Process Grid Row 0)
#    │ a21 a22 a23 a24 │
#    ├─────────────────┤
#    │ a31 a32 a33 a34 │ <- Rows 2–3 (Process Grid Row 1)
#    │ a41 a42 a43 a44 │
#    └─────────────────┘
#    </div>
#
# .. raw:: html
#
#    <div style="text-align: left; font-family: monospace; white-space: pre;">
#    <b>Matrix X (4 x 4):</b>
#    ┌─────────┬─────────┐
#    │ x11 x12 │ x13 x14 │ <- Cols 0–1 (Process Grid Col 0), Cols 2–3 (Process Grid Col 1)
#    │ x21 x22 │ x23 x24 │
#    │ x31 x32 │ x33 x34 │
#    │ x41 x42 │ x43 x44 │
#    └─────────┴─────────┘
#    </div>
#

blk_rows = int(math.ceil(N / p_prime))
blk_cols = int(math.ceil(M / p_prime))

rs = col_id * blk_rows
re = min(N, rs + blk_rows)
my_own_rows = max(0, re - rs)

cs = row_id * blk_cols
ce = min(M, cs + blk_cols)
my_own_cols = max(0, ce - cs)

A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()

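###############################################################################
# As a quick illustrative sanity check, one can print the shapes of the local
# blocks held by each rank, which should match
# :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}` and
# :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}` as described above:
print(f"Rank {rank}: A_p has shape {A_p.shape}, X_p has shape {X_p.shape}")
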
################################################################################
# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator and the input matrix :math:`\mathbf{X}`.
Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")

col_lens = comm.allgather(my_own_cols)
total_cols = np.sum(col_lens)
x = DistributedArray(global_shape=K * total_cols,
                     local_shapes=[K * col_len for col_len in col_lens],
                     partition=Partition.SCATTER,
                     mask=[i % p_prime for i in range(comm.Get_size())],
                     base_comm=comm,
                     dtype="float32")
x[:] = X_p.flatten()

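###############################################################################
# Each rank now holds, as its local portion of ``x``, the flattened entries of
# its own block of :math:`\mathbf{X}`. A small illustrative check (assuming the
# ``local_array`` property of :py:class:`pylops_mpi.DistributedArray`):
print(f"Rank {rank}: local x has {x.local_array.size} elements, "
      f"X_p has {X_p.size} elements")
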
################################################################################
# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
# effectively implements a distributed matrix-matrix multiplication
# :math:`\mathbf{Y} = \mathbf{AX}`). Note that :math:`\mathbf{Y}` is
# distributed in the same way as the input :math:`\mathbf{X}`.
y = Aop @ x

###############################################################################
# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
# (which effectively implements a distributed matrix-matrix multiplication
# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
# :math:`\mathbf{X}`.
xadj = Aop.H @ y

###############################################################################
# To conclude we verify our result against the equivalent serial version of
# the operation by gathering the resulting matrices on rank 0 and reorganizing
# the returned 1-D arrays into 2-D arrays.

# Gather the distributed results and rebuild the 2-D matrices locally
y = y.asarray(masked=True)
col_counts = [min(blk_cols, M - j * blk_cols) for j in range(p_prime)]
y_blocks = []
offset = 0
for cnt in col_counts:
    block_size = N * cnt
    y_block = y[offset: offset + block_size]
    if len(y_block) != 0:
        y_blocks.append(
            y_block.reshape(N, cnt)
        )
    offset += block_size
y = np.hstack(y_blocks)

xadj = xadj.asarray(masked=True)
xadj_blocks = []
offset = 0
for cnt in col_counts:
    block_size = K * cnt
    xadj_blk = xadj[offset: offset + block_size]
    if len(xadj_blk) != 0:
        xadj_blocks.append(
            xadj_blk.reshape(K, cnt)
        )
    offset += block_size
xadj = np.hstack(xadj_blocks)

if rank == 0:
    y_loc = (A @ X).squeeze()
    xadj_loc = (A.T.dot(y_loc.conj())).conj().squeeze()

    if not np.allclose(y, y_loc, rtol=1e-6):
        print("FORWARD VERIFICATION FAILED")
        print(f'distributed: {y}')
        print(f'expected: {y_loc}')
    else:
        print("FORWARD VERIFICATION PASSED")

    if not np.allclose(xadj, xadj_loc, rtol=1e-6):
        print("ADJOINT VERIFICATION FAILED")
        print(f'distributed: {xadj}')
        print(f'expected: {xadj_loc}')
    else:
        print("ADJOINT VERIFICATION PASSED")
