|
1 | 1 | """ |
2 | 2 | Distributed Matrix Multiplication |
3 | | -========================= |
4 | | -This example shows how to use the :py:class:`pylops_mpi.basicoperators.MatrixMultiply.SUMMAMatrixMult`. |
| 3 | +================================= |
| 4 | +This example shows how to use the :py:class:`pylops_mpi.basicoperators.MatrixMult.MPIMatrixMult`. |
5 | 5 | This class provides a way to distribute arrays across multiple processes in |
6 | 6 | a parallel computing environment. |
7 | 7 | """ |
8 | | - |
| 8 | +from matplotlib import pyplot as plt |
9 | 9 | import math |
10 | 10 | import numpy as np |
11 | 11 | from mpi4py import MPI |
12 | 12 |
|
13 | 13 | from pylops_mpi import DistributedArray, Partition |
14 | 14 | from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult |
15 | 15 |
|
| 16 | +plt.close("all") |
| 17 | +############################################################################### |
| 18 | +# We set the seed so that every rank starts out with the same matrices. |
| 19 | +# In practice, the data would instead be loaded in a manner appropriate to the use case. |
16 | 20 | np.random.seed(42) |
17 | 21 |
|
| 22 | +# MPI parameters |
18 | 23 | comm = MPI.COMM_WORLD |
19 | | -rank = comm.Get_rank() |
20 | | -n_procs = comm.Get_size() |
| 24 | +rank = comm.Get_rank() # rank of the current process |
| 25 | +size = comm.Get_size() # total number of processes |
21 | 26 |
|
22 | | -P_prime = int(math.ceil(math.sqrt(n_procs))) |
23 | | -C = int(math.ceil(n_procs / P_prime)) |
| 27 | +p_prime = int(math.ceil(math.sqrt(size))) |
| 28 | +C = int(math.ceil(size / p_prime)) |
24 | 29 |
|
25 | | -if (P_prime * C) != n_procs: |
| 30 | +if (p_prime * C) != size: |
26 | 31 | print("No. of procs has to be a square number") |
27 | 32 | exit(-1) |
28 | 33 |
|
29 | 34 | # matrix dims |
30 | | -M = 33 |
31 | | -K = 34 |
32 | | -N = 37 |
33 | | - |
| 35 | +M, K, N = 4, 4, 4 |
34 | 36 | A = np.random.rand(M * K).astype(dtype=np.float32).reshape(M, K) |
35 | | -B = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N) |
36 | | - |
37 | | -my_group = rank % P_prime |
38 | | -my_layer = rank // P_prime |
39 | | - |
40 | | -# sub‐communicators |
| 37 | +X = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N) |
| 38 | +################################################################################ |
| 39 | +#Process Grid Organization |
| 40 | +#************************* |
| 41 | +# |
| 42 | +#The processes are arranged in a grid of (approximately) :math:`\sqrt{P} \times \sqrt{P}`, where :math:`P` is the total number of processes. |
| 43 | +# |
| 44 | +#Define |
| 45 | +# |
| 46 | +#.. math:: |
| 47 | +# P' = \bigl \lceil \sqrt{P} \bigr \rceil |
| 48 | +# |
| 49 | +#and the replication factor |
| 50 | +# |
| 51 | +#.. math:: |
| 52 | +# C = \bigl\lceil \tfrac{P}{P'} \bigr\rceil. |
| 53 | +# |
| 54 | +#Each process is assigned a pair of coordinates :math:`(g, l)` within this grid: |
| 55 | +# |
| 56 | +#.. math:: |
| 57 | +# g = \mathrm{rank} \bmod P', |
| 58 | +# \quad |
| 59 | +# l = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor. |
| 60 | +# |
| 61 | +#For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout: |
| 62 | +# |
| 63 | +#.. raw:: html |
| 64 | +# |
| 65 | +# <div style="text-align: center; font-family: monospace; white-space: pre;"> |
| 66 | +# ┌────────────┬────────────┐ |
| 67 | +# │ Rank 0 │ Rank 1 │ |
| 68 | +# │ (g=0, l=0) │ (g=1, l=0) │ |
| 69 | +# ├────────────┼────────────┤ |
| 70 | +# │ Rank 2 │ Rank 3 │ |
| 71 | +# │ (g=0, l=1) │ (g=1, l=1) │ |
| 72 | +# └────────────┴────────────┘ |
| 73 | +# </div> |
| 74 | + |
| 75 | +my_group = rank % p_prime |
| 76 | +my_layer = rank // p_prime |
| 77 | + |
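###############################################################################
# As an optional illustration, the group/layer coordinates of every rank can be
# tabulated from ``size`` and ``p_prime`` alone; with four processes this
# reproduces the 2x2 layout shown above.
if rank == 0:
    for r in range(size):
        print(f"rank {r}: group g = {r % p_prime}, layer l = {r // p_prime}")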
| 78 | +# Create the sub‐communicators |
41 | 79 | layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer |
42 | 80 | group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group |
43 | 81 |
|
44 | | - |
45 | | -#Each rank will end up with: |
46 | | -# - :math:`A_{p} \in \mathbb{R}^{\text{my\_own\_rows}\times K}` |
47 | | -# - :math:`B_{p} \in \mathbb{R}^{K\times \text{my\_own\_cols}}` |
48 | | -# where |
49 | | -blk_rows = int(math.ceil(M / P_prime)) |
50 | | -blk_cols = int(math.ceil(N / P_prime)) |
| 82 | +blk_rows = int(math.ceil(M / p_prime)) |
| 83 | +blk_cols = int(math.ceil(N / p_prime)) |
51 | 84 |
|
52 | 85 | rs = my_group * blk_rows |
53 | 86 | re = min(M, rs + blk_rows) |
54 | 87 | my_own_rows = re - rs |
55 | 88 |
56 | 89 | cs = my_layer * blk_cols |
57 | 90 | ce = min(N, cs + blk_cols) |
58 | 91 | my_own_cols = ce - cs |
59 | 92 |
|
60 | | -A_p, B_p = A[rs:re, :].copy(), B[:, cs:ce].copy() |
61 | | - |
| 93 | +################################################################################ |
| 94 | +#Each rank will end up with: |
| 95 | +# - :math:`A_{p} \in \mathbb{R}^{\text{my\_own\_rows}\times K}` |
| 96 | +# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my\_own\_cols}}` |
| 97 | +#which are extracted as follows: |
| 98 | +A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy() |
| 99 | + |
| 100 | +################################################################################ |
| 101 | +#.. raw:: html |
| 102 | +# |
| 103 | +# <div style="text-align: left; font-family: monospace; white-space: pre;"> |
| 104 | +# <b>Matrix A (4 x 4):</b> |
| 105 | +# ┌─────────────────┐ |
| 106 | +# │ a11 a12 a13 a14 │ <- Rows 0–1 (Group 0) |
| 107 | +# │ a21 a22 a23 a24 │ |
| 108 | +# ├─────────────────┤ |
| 109 | +# │ a31 a32 a33 a34 │ <- Rows 2–3 (Group 1) |
| 110 | +# │ a41 a42 a43 a44 │ |
| 111 | +# └─────────────────┘ |
| 112 | +# </div> |
| 113 | +# |
| 114 | +#.. raw:: html |
| 115 | +# |
| 116 | +# <div style="text-align: left; font-family: monospace; white-space: pre;"> |
| 117 | +# <b>Matrix X (4 x 4):</b> |
| 118 | +# ┌─────────┬─────────┐ |
| 119 | +# │ x11 x12 │ x13 x14 │ <- Cols 0–1 (Layer 0), Cols 2–3 (Layer 1) |
| 120 | +# │ x21 x22 │ x23 x24 │ |
| 121 | +# │ x31 x32 │ x33 x34 │ |
| 122 | +# │ x41 x42 │ x43 x44 │ |
| 123 | +# └─────────┴─────────┘ |
| 124 | +# |
| 125 | +# </div> |
| 126 | +# |
| 127 | + |
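###############################################################################
# As an optional sanity check, each rank can confirm that its local blocks have
# the shapes sketched above:
assert A_p.shape == (my_own_rows, K)
assert X_p.shape == (K, my_own_cols)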
| 128 | +################################################################################ |
| 129 | +#To perform the distributed matrix-matrix multiplication :math:`Y = \text{Aop} \times X`, we need to create the distributed operator :math:`\text{Aop}` and the distributed operand :math:`X` from :math:`A_p` and |
| 130 | +#:math:`X_p`, respectively. |
62 | 131 | Aop = MPIMatrixMult(A_p, N, dtype="float32") |
| 132 | +################################################################################ |
| 133 | +# To build the operand, each rank gathers the number of columns owned by every rank, so that the global and local shapes of the DistributedArray can be specified. |
63 | 134 | col_lens = comm.allgather(my_own_cols) |
64 | 135 | total_cols = np.sum(col_lens) |
65 | 136 | x = DistributedArray(global_shape=K * total_cols, |
66 | 137 | local_shapes=[K * col_len for col_len in col_lens], |
67 | 138 | partition=Partition.SCATTER, |
68 | | - mask=[i % P_prime for i in range(comm.Get_size())], |
| 139 | + mask=[i // p_prime for i in range(comm.Get_size())], |
69 | 140 | base_comm=comm, |
70 | 141 | dtype="float32") |
71 | | -x[:] = B_p.flatten() |
| 142 | +x[:] = X_p.flatten() |
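###############################################################################
# At this point each rank holds exactly its own flattened block of ``X``
# (an optional sanity check):
assert x.local_array.size == K * my_own_cols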
| 143 | +################################################################################ |
| 144 | +#Applying the forward operator yields a distributed :math:`Y`, partitioned across the ranks in the same way as :math:`X`. |
72 | 145 | y = Aop @ x |
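###############################################################################
# The local part of ``y`` holds this rank's column block of the product, i.e.
# all M rows restricted to columns ``cs:ce`` (a sketch consistent with the
# verification below; ``Y_p`` is introduced here purely for illustration):
Y_p = y.local_array.reshape(M, ce - cs)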
73 | | - |
74 | | -# ======================= VERIFICATION =================-============= |
75 | | -y_loc = A @ B |
| 146 | +############################################################################### |
| 147 | +# In a similar fashion we apply the adjoint operation :math:`X_{adj} = A^H Y`. |
| 148 | +xadj = Aop.H @ y |
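###############################################################################
# As with the forward product, the local adjoint result can be viewed as this
# rank's column block of :math:`X_{adj}` (again a sketch for illustration only):
Xadj_p = xadj.local_array.reshape(K, ce - cs)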
| 149 | +############################################################################### |
| 150 | +#Here we verify the result against the equivalent serial computation. Each rank checks that it has computed the correct values for its partition. |
| 151 | +y_loc = A @ X |
76 | 152 | xadj_loc = (A.T.dot(y_loc.conj())).conj() |
77 | 153 |
|
78 | | - |
79 | 154 | expected_y_loc = y_loc[:, cs:ce].flatten().astype(np.float32) |
80 | 155 | expected_xadj_loc = xadj_loc[:, cs:ce].flatten().astype(np.float32) |
81 | 156 |
|
82 | | -xadj = Aop.H @ y |
83 | 157 | if not np.allclose(y.local_array, expected_y_loc, rtol=1e-6): |
84 | 158 | print(f"RANK {rank}: FORWARD VERIFICATION FAILED") |
85 | 159 | print(f'{rank} local: {y.local_array}, expected: {y_loc[:, cs:ce]}') |
|