1 | | -""" |
| 1 | +r""" |
2 | 2 | Distributed Matrix Multiplication |
3 | 3 | ================================= |
4 | 4 | This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult` |
5 | | -operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}` |
| 5 | +operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}` |
6 | 6 | blocked over rows (i.e., blocks of rows are stored over different ranks) and a |
7 | | -matrix :math:`\mathbf{X}` blocked over columns (i.e., blocks of columns are |
8 | | -stored over different ranks), with equal number of row and column blocks. |
9 | | -Similarly, the adjoint operation can be peformed with a matrix :math:`\mathbf{Y}` |
| 7 | +matrix :math:`\mathbf{X}` blocked over columns (i.e., blocks of columns are |
| 8 | +stored over different ranks), with equal number of row and column blocks. |
| 9 | +Similarly, the adjoint operation can be peformed with a matrix :math:`\mathbf{Y}` |
10 | 10 | blocked in the same fashion of matrix :math:`\mathbf{X}`. |
11 | 11 |
|
Note that whilst the different blocks of the matrix :math:`\mathbf{A}` are directly
stored in the operator on different ranks, the matrix :math:`\mathbf{X}` is
effectively represented by a 1-D :py:class:`pylops_mpi.DistributedArray` where
the different blocks are flattened and stored on different ranks. Note that to
optimize communications, the ranks are organized in a 2D grid and some of the
row blocks of :math:`\mathbf{A}` and column blocks of :math:`\mathbf{X}` are
replicated across different ranks - see below for details.

"""

import math

import numpy as np
from matplotlib import pyplot as plt
from mpi4py import MPI

from pylops_mpi import DistributedArray, Partition
from pylops_mpi.basicoperators import MPIMatrixMult

plt.close("all")

###############################################################################
# We set the seed such that all processes can create the input matrices filled
# with the same random numbers. In practical applications, such matrices will
# be filled with data that is appropriate to the use-case.
np.random.seed(42)

###############################################################################
# Next we obtain the MPI parameters for each rank and check that the number
# of processes (``size``) is a square number
comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # rank of current process
size = comm.Get_size()  # number of processes

p_prime = math.isqrt(size)
repl_factor = p_prime
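
# A minimal sketch (an assumption, not necessarily the example's own check) of
# the square-number requirement mentioned above: the grid side ``p_prime`` must
# satisfy ``p_prime * p_prime == size``.
if p_prime * p_prime != size:
    raise ValueError(f"Number of processes ({size}) must be a perfect square")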
# ...

X = np.random.rand(K * M).astype(dtype=np.float32).reshape(K, M)

###############################################################################
# The processes are now arranged in a :math:`\sqrt{P} \times \sqrt{P}` grid,
# where :math:`P` is the total number of processes.
#
# We define
# ...
# .. math::
#    R = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
#
# Each process is therefore assigned a pair of coordinates
# :math:`(r,c)` within this grid:
#
# .. math::
# ...

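# A minimal sketch (an assumption, not the example's own code) of how the grid
# coordinates and the row communicator might be derived:
my_row, my_col = divmod(rank, p_prime)           # (r, c) in the P' x P' grid
row_comm = comm.Split(color=my_row, key=my_col)  # all procs in same row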
col_comm = comm.Split(color=my_col, key=my_row)  # all procs in same col

###############################################################################
# At this point we divide the rows and columns of :math:`\mathbf{A}` and
# :math:`\mathbf{X}`, respectively, such that each rank ends up with:
#
# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}`
# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}`
# ...

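# A minimal sketch (an assumption, not the example's own partitioning code) of
# how the local block boundaries rs:re (rows of A) and cs:ce (columns of X)
# could be obtained by splitting them evenly across the process grid:
blk_rows = math.ceil(A.shape[0] / p_prime)
rs, re = my_row * blk_rows, min(A.shape[0], (my_row + 1) * blk_rows)
my_own_rows = re - rs
blk_cols = math.ceil(X.shape[1] / p_prime)
cs, ce = my_col * blk_cols, min(X.shape[1], (my_col + 1) * blk_cols)
my_own_cols = ce - cs
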
A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()

###############################################################################
# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator and the input matrix :math:`\mathbf{X}`
Aop = MPIMatrixMult(A_p, M, dtype="float32")

col_lens = comm.allgather(my_own_cols)
total_cols = np.sum(col_lens)
x = DistributedArray(global_shape=K * total_cols,
                     local_shapes=[K * col_len for col_len in col_lens],
                     partition=Partition.SCATTER,
                     # ...
x[:] = X_p.flatten()
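
# A small sanity check (an assumption, not part of the original example): each
# rank's local portion of ``x`` should hold the flattened K x my_own_cols block.
assert x.local_array.size == K * my_own_cols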

###############################################################################
# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which effectively
# implements a distributed matrix-matrix multiplication :math:`\mathbf{Y} = \mathbf{AX}`).
# Note that :math:`\mathbf{Y}` is distributed in the same way as the input
# :math:`\mathbf{X}`.
y = Aop @ x

###############################################################################
# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
# (which effectively implements a distributed matrix-matrix multiplication
# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
# :math:`\mathbf{X}`.
xadj = Aop.H @ y

###############################################################################
# To conclude we verify our result against the equivalent serial version of
# the operation by gathering the resulting matrices on rank 0 and reorganizing
# the returned 1D arrays into 2D arrays.

# ...

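# A minimal sketch (an assumption): one way to collect the distributed results
# as ordinary NumPy arrays on every rank is ``asarray()``; the example's own
# gathering and reshaping into 2D blocks is not shown here.
y = y.asarray()
xadj = xadj.asarray()
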
xadj_loc = (A.T.dot(y_loc.conj())).conj().squeeze()
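# Note: the serial adjoint above relies on the identity
# :math:`\mathbf{A}^H \mathbf{Y} = \overline{\mathbf{A}^T \overline{\mathbf{Y}}}`,
# so that :math:`\mathbf{A}^H` never needs to be formed explicitly.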

if not np.allclose(y, y_loc, rtol=1e-6):
    print("FORWARD VERIFICATION FAILED")
    print(f"distributed: {y}")
    print(f"expected: {y_loc}")
else:
    print("FORWARD VERIFICATION PASSED")

if not np.allclose(xadj, xadj_loc, rtol=1e-6):
    print("ADJOINT VERIFICATION FAILED")
    print(f"distributed: {xadj}")
    print(f"expected: {xadj_loc}")
else:
    print("ADJOINT VERIFICATION PASSED")
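
###############################################################################
# This example is meant to be run with a square number of MPI processes, e.g.
# (assuming the script is saved as ``matrixmult.py``, a hypothetical name):
#
#     mpiexec -n 4 python matrixmult.py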