Commit a110ff8

minor: cleanup of docstrings and updated example
1 parent 42452a1 commit a110ff8

4 files changed: +159 additions, -106 deletions

docs/source/api/index.rst

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ Basic Operators
 .. autosummary::
    :toctree: generated/

+   MPIMatrixMult
    MPIBlockDiag
    MPIStackedBlockDiag
    MPIVStack

examples/plot_matrixmult.py

Lines changed: 116 additions & 76 deletions
@@ -1,9 +1,22 @@
 """
 Distributed Matrix Multiplication
 =================================
-This example shows how to use the :py:class:`pylops_mpi.basicoperators.MatrixMult.MPIMatrixMult`.
-This class provides a way to distribute arrays across multiple processes in
-a parallel computing environment.
+This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}`
+blocked over rows (i.e., blocks of rows are stored over different ranks) and a
+matrix :math:`\mathbf{X}` blocked over columns (i.e., blocks of columns are
+stored over different ranks), with an equal number of row and column blocks.
+Similarly, the adjoint operation can be performed with a matrix :math:`\mathbf{Y}`
+blocked in the same fashion as matrix :math:`\mathbf{X}`.
+
+Note that whilst the different blocks of the matrix :math:`\mathbf{A}` are directly
+stored in the operator on different ranks, the matrix :math:`\mathbf{X}` is
+effectively represented by a 1-D :py:class:`pylops_mpi.DistributedArray` where
+the different blocks are flattened and stored on different ranks. Moreover, to
+optimize communications, the ranks are organized in a 2D grid and some of the
+row blocks of :math:`\mathbf{A}` and column blocks of :math:`\mathbf{X}` are
+replicated across different ranks; see below for details.
+
 """
 from matplotlib import pyplot as plt
 import math
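
As an aside to the updated docstring above, the following NumPy-only sketch (no MPI required; the sizes and the layer index are assumed purely for illustration) shows how a column block of X is flattened into the local portion of the 1-D DistributedArray and recovered again:

    import math
    import numpy as np

    K, N, p_prime = 4, 4, 2                 # sizes assumed to match the example
    X = np.arange(K * N, dtype=np.float32).reshape(K, N)

    blk_cols = int(math.ceil(N / p_prime))  # columns owned by each layer
    my_layer = 1                            # pretend we sit in layer 1
    cs = my_layer * blk_cols
    ce = min(N, cs + blk_cols)

    X_p = X[:, cs:ce].copy()                # local column block, shape (K, ce - cs)
    x_local = X_p.flatten()                 # what x[:] = X_p.flatten() stores locally
    X_back = x_local.reshape(K, ce - cs)    # undo the flattening to recover the block
    assert np.array_equal(X_back, X_p)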
@@ -14,49 +27,56 @@
 from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
 
 plt.close("all")
+
 ###############################################################################
-# We set the seed such that all processes initially start out with the same initial matrix.
-# Ideally this data would be loaded in a manner appropriate to the use-case.
+# We set the seed such that all processes create the input matrices filled
+# with the same random numbers. In practical applications, such matrices will
+# be filled with data appropriate to the use-case.
 np.random.seed(42)
 
-# MPI parameters
+###############################################################################
+# Next we obtain the MPI parameters for each rank and check that the number
+# of processes (``size``) is a square number.
 comm = MPI.COMM_WORLD
 rank = comm.Get_rank()  # rank of current process
 size = comm.Get_size()  # number of processes
 
 p_prime = int(math.ceil(math.sqrt(size)))
-C = int(math.ceil(size / p_prime))
+repl_factor = int(math.ceil(size / p_prime))
 
-if (p_prime * C) != size:
-    print("No. of procs has to be a square number")
+if (p_prime * repl_factor) != size:
+    print(f"Number of processes must be a square number, provided {size} instead...")
     exit(-1)
 
-# matrix dims
+###############################################################################
+# We are now ready to create the input matrices: :math:`\mathbf{A}` of size
+# :math:`M \times K` and :math:`\mathbf{X}` of size :math:`K \times N`.
 M, K, N = 4, 4, 4
 A = np.random.rand(M * K).astype(dtype=np.float32).reshape(M, K)
 X = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N)
+
 ################################################################################
-#Process Grid Organization
-#*************************
-#
-#The processes are arranged in a :math:`\sqrt{P} \times \sqrt{P}` grid, where :math:`P` is the total number of processes.
+# The processes are now arranged in a :math:`\sqrt{P} \times \sqrt{P}` grid,
+# where :math:`P` is the total number of processes.
 #
-#Define
+# We define
 #
-#.. math::
-#   P' = \bigl \lceil \sqrt{P} \bigr \rceil
+# .. math::
+#    P' = \bigl \lceil \sqrt{P} \bigr \rceil
 #
-#and the replication factor
+# and the replication factor
 #
-#.. math::
-#   C = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
+# .. math::
+#    R = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
 #
-#Each process is assigned a pair of coordinates :math:`(g, l)` within this grid:
+# Each process is therefore assigned a pair of coordinates
+# :math:`(g, l)` within this grid:
 #
-#.. math::
-#   g = \mathrm{rank} \bmod P',
-#   \quad
-#   l = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor.
+# .. math::
+#    g = \mathrm{rank} \bmod P',
+#    \quad
+#    l = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor.
 #
 #For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
 #
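
As a quick check of the grid formulas introduced in this hunk, here is a pure-Python sketch (no MPI needed; P = 4 is assumed) that prints the (g, l) coordinates of each rank:

    import math

    P = 4                                      # assumed number of processes
    p_prime = int(math.ceil(math.sqrt(P)))     # P' = ceil(sqrt(P)) = 2
    repl_factor = int(math.ceil(P / p_prime))  # R = ceil(P / P') = 2

    for rank in range(P):
        g = rank % p_prime    # group (position within a layer)
        l = rank // p_prime   # layer
        print(f"rank {rank}: group g = {g}, layer l = {l}")
    # -> (0, 0), (1, 0), (0, 1), (1, 1): the 2x2 layout described above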
@@ -75,30 +95,18 @@
 my_group = rank % p_prime
 my_layer = rank // p_prime
 
-# Create the sub‐communicators
+# Create sub‐communicators
 layer_comm = comm.Split(color=my_layer, key=my_group)  # all procs in same layer
 group_comm = comm.Split(color=my_group, key=my_layer)  # all procs in same group
 
-blk_rows = int(math.ceil(M / p_prime))
-blk_cols = int(math.ceil(N / p_prime))
-
-rs = my_group * blk_rows
-re = min(M, rs + blk_rows)
-my_own_rows = re - rs
-
-cs = my_layer * blk_cols
-ce = min(N, cs + blk_cols)
-my_own_cols = ce - cs
-
 ################################################################################
-#Each rank will end up with:
-# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}`
-# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}`
-#as follows:
-A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()
-
-################################################################################
-#.. raw:: html
+# At this point we divide the rows and columns of :math:`\mathbf{A}` and
+# :math:`\mathbf{X}`, respectively, such that each rank ends up with:
+#
+# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}`
+# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}`
+#
+# .. raw:: html
 #
 #   <div style="text-align: left; font-family: monospace; white-space: pre;">
 #   <b>Matrix A (4 x 4):</b>
@@ -111,10 +119,10 @@
 #   └─────────────────┘
 #   </div>
 #
-#.. raw:: html
+# .. raw:: html
 #
 #   <div style="text-align: left; font-family: monospace; white-space: pre;">
-#   <b>Matrix B (4 x 4):</b>
+#   <b>Matrix X (4 x 4):</b>
 #   ┌─────────┬─────────┐
 #   │ b11 b12 │ b13 b14 │ <- Cols 0–1 (Layer 0), Cols 2–3 (Layer 1)
 #   │ b21 b22 │ b23 b24 │
@@ -125,48 +133,80 @@
 #   </div>
 #
 
+blk_rows = int(math.ceil(M / p_prime))
+blk_cols = int(math.ceil(N / p_prime))
+
+rs = my_group * blk_rows
+re = min(M, rs + blk_rows)
+my_own_rows = re - rs
+
+cs = my_layer * blk_cols
+ce = min(N, cs + blk_cols)
+my_own_cols = ce - cs
+
+A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()
+
 ################################################################################
-#Forward Operation
-#*****************
-#To perform our distributed matrix-matrix multiplication :math:`Y = \text{Aop} \times X` we need to create our distributed operator :math:`\text{Aop}` and distributed operand :math:`X` from :math:`A_p` and
-#:math:`X_p` respectively
+# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+# operator and the input matrix :math:`\mathbf{X}`.
 Aop = MPIMatrixMult(A_p, N, dtype="float32")
-################################################################################
-# While as well passing the appropriate values.
+
 col_lens = comm.allgather(my_own_cols)
-total_cols = np.sum(col_lens)
-x = DistributedArray(global_shape=K * total_cols,
+x = DistributedArray(global_shape=K * N,
                      local_shapes=[K * col_len for col_len in col_lens],
                      partition=Partition.SCATTER,
-                     mask=[i // p_prime for i in range(comm.Get_size())],
+                     mask=[i % p_prime for i in range(comm.Get_size())],
                      base_comm=comm,
                      dtype="float32")
 x[:] = X_p.flatten()
+
 ################################################################################
-#When we perform the matrix-matrix multiplication we shall then obtain a distributed :math:`Y` in the same way our :math:`X` was distributed.
+# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which effectively
+# implements a distributed matrix-matrix multiplication :math:`\mathbf{Y} = \mathbf{AX}`).
+# Note that :math:`\mathbf{Y}` is distributed in the same way as the input
+# :math:`\mathbf{X}`.
 y = Aop @ x
+
 ###############################################################################
-#Adjoint Operation
-#*****************
-# In a similar fashion we then perform the Adjoint :math:`Xadj = A^H * Y`
+# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
+# (which effectively implements a distributed matrix-matrix multiplication
+# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
+# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
+# :math:`\mathbf{X}`.
 xadj = Aop.H @ y
+
 ###############################################################################
-#Here we verify the result against the equivalent serial version of the operation. Each rank checks that it has computed the correct values for it partition.
+# To conclude we verify our result against the equivalent serial version of
+# the operation by gathering the resulting matrices on rank 0 and reorganizing
+# the returned 1D arrays into 2D arrays.
+
+# Local benchmarks
 y_loc = A @ X
 xadj_loc = (A.T.dot(y_loc.conj())).conj()
 
-expected_y_loc = y_loc[:, cs:ce].flatten().astype(np.float32)
-expected_xadj_loc = xadj_loc[:, cs:ce].flatten().astype(np.float32)
-
-if not np.allclose(y.local_array, expected_y_loc, rtol=1e-6):
-    print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
-    print(f'{rank} local: {y.local_array}, expected: {y_loc[:, cs:ce]}')
-else:
-    print(f"RANK {rank}: FORWARD VERIFICATION PASSED")
-
-if not np.allclose(xadj.local_array, expected_xadj_loc, rtol=1e-6):
-    print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
-    print(f'{rank} local: {xadj.local_array}, expected: {xadj_loc[:, cs:ce]}')
-else:
-    print(f"RANK {rank}: ADJOINT VERIFICATION PASSED")
-
+y = y.asarray(masked=True)
+if N > 1:
+    y = y.reshape(p_prime, M, blk_cols)
+    y = np.hstack([yblock for yblock in y])
+xadj = xadj.asarray(masked=True)
+if N > 1:
+    xadj = xadj.reshape(p_prime, K, blk_cols)
+    xadj = np.hstack([xadjblock for xadjblock in xadj])
+
+if rank == 0:
+    y_loc = (A @ X).squeeze()
+    xadj_loc = (A.T.dot(y_loc.conj())).conj().squeeze()
+
+    if not np.allclose(y, y_loc, rtol=1e-6):
+        print("FORWARD VERIFICATION FAILED")
+        print(f'distributed: {y}')
+        print(f'expected: {y_loc}')
+    else:
+        print("FORWARD VERIFICATION PASSED")
+
+    if not np.allclose(xadj, xadj_loc, rtol=1e-6):
+        print("ADJOINT VERIFICATION FAILED")
+        print(f'distributed: {xadj}')
+        print(f'expected: {xadj_loc}')
+    else:
+        print("ADJOINT VERIFICATION PASSED")
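
For reference on the serial benchmark used in the verification hunk above, a small NumPy sketch (sizes assumed as in the example) confirming that the expression (A.T.dot(Y.conj())).conj() is simply the conjugate transpose A^H applied to Y:

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.standard_normal((4, 4)).astype(np.float32)
    X = rng.standard_normal((4, 4)).astype(np.float32)

    Y = A @ X                            # serial forward product
    Xadj_1 = (A.T.dot(Y.conj())).conj()  # form used in the example
    Xadj_2 = A.conj().T @ Y              # plain A^H @ Y
    assert np.allclose(Xadj_1, Xadj_2)   # both give X_adj = A^H Y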
