| 1 | +r""" |
| 2 | +Distributed Matrix Multiplication |
| 3 | +================================= |
| 4 | +This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult` |
| 5 | +operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}` |
| 6 | +blocked over rows (i.e., blocks of rows are stored over different ranks) and a |
| 7 | +matrix :math:`\mathbf{X}` blocked over columns (i.e., blocks of columns are |
| 8 | +stored over different ranks), with equal number of row and column blocks. |
| 9 | +Similarly, the adjoint operation can be peformed with a matrix :math:`\mathbf{Y}` |
| 10 | +blocked in the same fashion of matrix :math:`\mathbf{X}`. |
| 11 | +
|
| 12 | +Note that whilst the different blocks of the matrix :math:`\mathbf{A}` are directly |
| 13 | +stored in the operator on different ranks, the matrix :math:`\mathbf{X}` is |
| 14 | +effectively represented by a 1-D :py:class:`pylops_mpi.DistributedArray` where |
| 15 | +the different blocks are flattened and stored on different ranks. Note that to |
| 16 | +optimize communications, the ranks are organized in a 2D grid and some of the |
| 17 | +row blocks of :math:`\mathbf{A}` and column blocks of :math:`\mathbf{X}` are |
| 18 | +replicated across different ranks - see below for details. |
| 19 | +
|
| 20 | +""" |
import math

import numpy as np
from matplotlib import pyplot as plt
from mpi4py import MPI

from pylops_mpi import DistributedArray, Partition
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

plt.close("all")

###############################################################################
# We set the seed such that all processes create the input matrices filled
# with the same random numbers. In practical applications, such matrices will
# be filled with data that is appropriate for the use case.
np.random.seed(42)

###############################################################################
# We are now ready to create the input matrices: :math:`\mathbf{A}` of size
# :math:`N \times K` and :math:`\mathbf{X}` of size :math:`K \times M`.
N, K, M = 4, 4, 4
A = np.random.rand(N * K).astype(dtype=np.float32).reshape(N, K)
X = np.random.rand(K * M).astype(dtype=np.float32).reshape(K, M)

###############################################################################
# The processes are now arranged in a :math:`P' \times P'` grid,
# where :math:`P` is the total number of processes.
#
# We define
#
# .. math::
#    P' = \bigl \lceil \sqrt{P} \bigr \rceil
#
# and the replication factor
#
# .. math::
#    R = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
#
# Each process is therefore assigned a pair of coordinates
# :math:`(r,c)` within this grid:
#
# .. math::
#    r = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor,
#    \quad
#    c = \mathrm{rank} \bmod P'.
#
# For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
#
# .. raw:: html
#
#    <div style="text-align: center; font-family: monospace; white-space: pre;">
#    ┌────────────┬────────────┐
#    │ Rank 0     │ Rank 1     │
#    │ (r=0, c=0) │ (r=0, c=1) │
#    ├────────────┼────────────┤
#    │ Rank 2     │ Rank 3     │
#    │ (r=1, c=0) │ (r=1, c=1) │
#    └────────────┴────────────┘
#    </div>
#
# This grid is obtained by invoking the
# :func:`pylops_mpi.MPIMatrixMult.active_grid_comm` method, which is also
# responsible for identifying any rank that should be deactivated (i.e., when
# the number of rows of the operator or columns of the input/output matrices
# is smaller than the number of row or column ranks).

base_comm = MPI.COMM_WORLD
comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
if not is_active:
    exit(0)

# Create sub-communicators
p_prime = math.isqrt(comm.Get_size())
row_comm = comm.Split(color=row_id, key=col_id)  # all procs in same row
col_comm = comm.Split(color=col_id, key=row_id)  # all procs in same col

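###############################################################################
# Purely as an illustrative check (not required for the computation), each
# active rank can report its grid coordinates and the size of its row and
# column sub-communicators; with 4 processes this should match the 2x2 layout
# sketched above.
print(f"Rank {rank}: grid coords (r={row_id}, c={col_id}), "
      f"row_comm size={row_comm.Get_size()}, "
      f"col_comm size={col_comm.Get_size()}")
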
###############################################################################
# At this point we divide the rows and columns of :math:`\mathbf{A}` and
# :math:`\mathbf{X}`, respectively, such that each rank ends up with:
#
# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}`
# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}`
#
# .. raw:: html
#
#    <div style="text-align: left; font-family: monospace; white-space: pre;">
#    <b>Matrix A (4 x 4):</b>
#    ┌─────────────────┐
#    │ a11 a12 a13 a14 │ <- Rows 0–1 (Process Grid Row 0)
#    │ a21 a22 a23 a24 │
#    ├─────────────────┤
#    │ a31 a32 a33 a34 │ <- Rows 2–3 (Process Grid Row 1)
#    │ a41 a42 a43 a44 │
#    └─────────────────┘
#    </div>
#
# .. raw:: html
#
#    <div style="text-align: left; font-family: monospace; white-space: pre;">
#    <b>Matrix X (4 x 4):</b>
#    ┌─────────┬─────────┐
#    │ x11 x12 │ x13 x14 │ <- Cols 0–1 (Process Grid Col 0), Cols 2–3 (Process Grid Col 1)
#    │ x21 x22 │ x23 x24 │
#    │ x31 x32 │ x33 x34 │
#    │ x41 x42 │ x43 x44 │
#    └─────────┴─────────┘
#    </div>
#
blk_rows = int(math.ceil(N / p_prime))
blk_cols = int(math.ceil(M / p_prime))

# Rows of A owned (and possibly replicated) by this rank
rs = col_id * blk_rows
re = min(N, rs + blk_rows)
my_own_rows = max(0, re - rs)

# Columns of X owned (and possibly replicated) by this rank
cs = row_id * blk_cols
ce = min(M, cs + blk_cols)
my_own_cols = max(0, ce - cs)

A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()
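###############################################################################
# As a quick illustration, each rank can print the shape of its local blocks:
# ``A_p`` should have ``my_own_rows`` rows and ``X_p`` should have
# ``my_own_cols`` columns (this is only a sanity check and can be removed).
print(f"Rank {rank}: A_p shape={A_p.shape}, X_p shape={X_p.shape}")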

###############################################################################
# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator and the input matrix :math:`\mathbf{X}`.
Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")

col_lens = comm.allgather(my_own_cols)
total_cols = np.sum(col_lens)
x = DistributedArray(global_shape=K * total_cols,
                     local_shapes=[K * col_len for col_len in col_lens],
                     partition=Partition.SCATTER,
                     mask=[i % p_prime for i in range(comm.Get_size())],
                     base_comm=comm,
                     dtype="float32")
x[:] = X_p.flatten()
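###############################################################################
# The local portion of ``x`` held by each rank is simply the flattened local
# block ``X_p``. The check below is purely illustrative and relies on the
# ``local_array`` attribute of :py:class:`pylops_mpi.DistributedArray`, which
# exposes the rank-local buffer.
print(f"Rank {rank}: local x equals flattened X_p: "
      f"{np.allclose(x.local_array, X_p.ravel())}")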

###############################################################################
# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
# effectively implements a distributed matrix-matrix multiplication
# :math:`\mathbf{Y} = \mathbf{AX}`). Note that :math:`\mathbf{Y}` is
# distributed in the same way as the input :math:`\mathbf{X}`.
y = Aop @ x
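###############################################################################
# As an illustrative check of that distribution, the local portion of ``y`` on
# each rank should contain ``N * my_own_cols`` elements, i.e., the flattened
# local block of :math:`\mathbf{Y}` (again using the ``local_array`` attribute
# of :py:class:`pylops_mpi.DistributedArray`).
print(f"Rank {rank}: local y size={y.local_array.size} "
      f"(expected {N * my_own_cols})")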

###############################################################################
# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
# (which effectively implements a distributed matrix-matrix multiplication
# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
# :math:`\mathbf{X}`.
xadj = Aop.H @ y
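###############################################################################
# Similarly, the local portion of ``xadj`` should contain ``K * my_own_cols``
# elements, confirming that the adjoint output follows the same distribution
# as the input (illustrative check only).
print(f"Rank {rank}: local xadj size={xadj.local_array.size} "
      f"(expected {K * my_own_cols})")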

###############################################################################
# To conclude, we verify our result against the equivalent serial version of
# the operation by gathering the resulting matrices on rank 0 and reorganizing
# the returned 1D arrays into 2D arrays.

# Gather the distributed y and reshape it into its 2D form
y = y.asarray(masked=True)
col_counts = [min(blk_cols, M - j * blk_cols) for j in range(p_prime)]
y_blocks = []
offset = 0
for cnt in col_counts:
    block_size = N * cnt
    y_block = y[offset: offset + block_size]
    if len(y_block) != 0:
        y_blocks.append(
            y_block.reshape(N, cnt)
        )
    offset += block_size
y = np.hstack(y_blocks)

# Gather the distributed xadj and reshape it into its 2D form
xadj = xadj.asarray(masked=True)
xadj_blocks = []
offset = 0
for cnt in col_counts:
    block_size = K * cnt
    xadj_blk = xadj[offset: offset + block_size]
    if len(xadj_blk) != 0:
        xadj_blocks.append(
            xadj_blk.reshape(K, cnt)
        )
    offset += block_size
xadj = np.hstack(xadj_blocks)

if rank == 0:
    # Serial computation used as reference
    y_loc = (A @ X).squeeze()
    xadj_loc = (A.T.dot(y_loc.conj())).conj().squeeze()

    if not np.allclose(y, y_loc, rtol=1e-6):
        print("FORWARD VERIFICATION FAILED")
        print(f'distributed: {y}')
        print(f'expected: {y_loc}')
    else:
        print("FORWARD VERIFICATION PASSED")

    if not np.allclose(xadj, xadj_loc, rtol=1e-6):
        print("ADJOINT VERIFICATION FAILED")
        print(f'distributed: {xadj}')
        print(f'expected: {xadj_loc}')
    else:
        print("ADJOINT VERIFICATION PASSED")