r"""
Distributed Matrix Multiplication - SUMMA
=========================================
This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
operator with ``kind='summa'`` to perform matrix-matrix multiplication between
a matrix :math:`\mathbf{A}` distributed in 2D blocks across a square process
grid and matrices :math:`\mathbf{X}` and :math:`\mathbf{Y}` distributed in 2D
blocks across the same grid. Similarly, the adjoint operation can be performed
with a matrix :math:`\mathbf{Y}` distributed in the same fashion as matrix
:math:`\mathbf{X}`.

Note that whilst the different blocks of matrix :math:`\mathbf{A}` are directly
stored in the operator on different ranks, the matrices :math:`\mathbf{X}` and
:math:`\mathbf{Y}` are effectively represented by 1-D :py:class:`pylops_mpi.DistributedArray`
objects where the different blocks are flattened and stored on different ranks.
Note that to optimize communications, the ranks are organized in a square grid and
blocks of :math:`\mathbf{A}` and :math:`\mathbf{X}` are systematically broadcast
across different ranks during computation - see below for details.
"""

import math
import numpy as np
from mpi4py import MPI
from matplotlib import pyplot as plt

import pylops_mpi
from pylops import Conj
from pylops_mpi.basicoperators.MatrixMult import \
    local_block_split, MPIMatrixMult, active_grid_comm

plt.close("all")

###############################################################################
# We set the seed such that all processes can create input matrices filled
# with the same random numbers. In practical applications, such matrices will
# be filled with data that is appropriate to the use-case.
np.random.seed(42)

###############################################################################
# We are now ready to create the input matrices for our distributed matrix
# multiplication example. We need to set up:
#
# - Matrix :math:`\mathbf{A}` of size :math:`N \times K` (the left operand)
# - Matrix :math:`\mathbf{X}` of size :math:`K \times M` (the right operand)
# - The result will be :math:`\mathbf{Y} = \mathbf{A} \mathbf{X}` of size
#   :math:`N \times M`
#
# We create here global test matrices with sequential values for easy verification:
#
# - Matrix A: Each element :math:`A_{i,j} = i \cdot K + j` (row-major ordering)
# - Matrix X: Each element :math:`X_{i,j} = i \cdot M + j`

N, M, K = 6, 6, 6
A_shape, x_shape, y_shape = (N, K), (K, M), (N, M)

A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape)

################################################################################
# For distributed computation, we arrange processes in a square grid of size
# :math:`P' \times P'` where :math:`P' = \sqrt{P}` and :math:`P` is the total
# number of MPI processes. Each process will own a block of each matrix
# according to this 2D grid layout.

base_comm = MPI.COMM_WORLD
comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")

p_prime = math.isqrt(comm.Get_size())
print(f"Process grid: {p_prime} x {p_prime} = {comm.Get_size()} processes")

if rank == 0:
    print(f"Global matrix A shape: {A_shape} (N={A_shape[0]}, K={A_shape[1]})")
    print(f"Global matrix X shape: {x_shape} (K={x_shape[0]}, M={x_shape[1]})")
    print(f"Expected global result Y shape: ({A_shape[0]}, {x_shape[1]}) = (N, M)")
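
###############################################################################
# To make the grid layout concrete, we can print this rank's position in the
# :math:`P' \times P'` grid as reported by ``active_grid_comm``. The ``divmod``
# comparison below is only an illustration: it assumes a row-major mapping of
# ranks to grid positions, which is not guaranteed by the library.
print(f"Process {rank}: grid position (row={row_id}, col={col_id})")
guess_row, guess_col = divmod(rank, p_prime)  # assumed row-major rank layout
print(f"Process {rank}: row-major guess (row={guess_row}, col={guess_col})")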

################################################################################
# Next we must determine which block of each matrix each process should own.
#
# The 2D block distribution requires:
#
# - Process at grid position :math:`(i,j)` gets block
#   :math:`\mathbf{A}[i_{start}:i_{end}, j_{start}:j_{end}]`
# - Block sizes are approximately :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil`,
#   with edge processes handling the remainder
#
# .. raw:: html
#
#   <div style="text-align: left; font-family: monospace; white-space: pre;">
#   <b>Example: 2x2 Process Grid with 6x6 Matrices</b>
#
#   Matrix A (6x6):                Matrix X (6x6):
#   ┌───────────┬───────────┐      ┌───────────┬───────────┐
#   │  0  1  2  │  3  4  5  │      │  0  1  2  │  3  4  5  │
#   │  6  7  8  │  9 10 11  │      │  6  7  8  │  9 10 11  │
#   │ 12 13 14  │ 15 16 17  │      │ 12 13 14  │ 15 16 17  │
#   ├───────────┼───────────┤      ├───────────┼───────────┤
#   │ 18 19 20  │ 21 22 23  │      │ 18 19 20  │ 21 22 23  │
#   │ 24 25 26  │ 27 28 29  │      │ 24 25 26  │ 27 28 29  │
#   │ 30 31 32  │ 33 34 35  │      │ 30 31 32  │ 33 34 35  │
#   └───────────┴───────────┘      └───────────┴───────────┘
#
#   Process (0,0): A[0:3, 0:3], X[0:3, 0:3]
#   Process (0,1): A[0:3, 3:6], X[0:3, 3:6]
#   Process (1,0): A[3:6, 0:3], X[3:6, 0:3]
#   Process (1,1): A[3:6, 3:6], X[3:6, 3:6]
#   </div>
#

A_slice = local_block_split(A_shape, rank, comm)
x_slice = local_block_split(x_shape, rank, comm)

################################################################################
# Extract the local portion of each matrix for this process
A_local = A_data[A_slice]
x_local = x_data[x_slice]

print(f"Process {rank}: A_local shape {A_local.shape}, X_local shape {x_local.shape}")
print(f"Process {rank}: A slice {A_slice}, X slice {x_slice}")
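
################################################################################
# As a rough cross-check (a sketch only: the ceil-division formula below mirrors
# the nominal block size quoted above, not necessarily the exact rule used by
# ``local_block_split``), we can guess this rank's block of :math:`\mathbf{A}`
# from its grid position and compare it with the slice returned above.
blk_rows = math.ceil(N / p_prime)  # nominal block height
blk_cols = math.ceil(K / p_prime)  # nominal block width
guess = (slice(row_id * blk_rows, min((row_id + 1) * blk_rows, N)),
         slice(col_id * blk_cols, min((col_id + 1) * blk_cols, K)))
print(f"Process {rank}: guessed A block {guess}, actual {A_slice}")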

################################################################################

################################################################################
# We are now ready to create the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator and the input matrix :math:`\mathbf{X}`

Aop = MPIMatrixMult(A_local, M, base_comm=comm, kind="summa", dtype=A_local.dtype)

x_dist = pylops_mpi.DistributedArray(
    global_shape=(K * M),
    local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]),
    base_comm=comm,
    partition=pylops_mpi.Partition.SCATTER,
    dtype=x_local.dtype)
x_dist[:] = x_local.flatten()

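################################################################################
# As a quick sanity check, we reshape the local 1-D portion of ``x_dist`` back
# into its 2D block and compare it with ``x_local``. This is a minimal sketch
# that assumes the ``local_array`` property of
# :py:class:`pylops_mpi.DistributedArray` returns this rank's flattened block.
print(f"Process {rank}: local block of x_dist matches x_local: "
      f"{np.array_equal(x_dist.local_array.reshape(x_local.shape), x_local)}")
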
################################################################################
# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
# effectively implements a distributed matrix-matrix multiplication
# :math:`\mathbf{Y} = \mathbf{AX}`). Note that :math:`\mathbf{Y}` is distributed
# in the same way as the input :math:`\mathbf{X}` in a block-block fashion.
y_dist = Aop @ x_dist

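###############################################################################
# To verify the distributed result, each rank can recompute the global product
# with NumPy (every process holds the full ``A_data`` and ``x_data``) and
# compare its own block of :math:`\mathbf{Y}`. This sketch assumes that the
# blocks of :math:`\mathbf{Y}` follow the same splitting rule as those of
# :math:`\mathbf{X}`, obtained here via ``local_block_split``.
y_glob = A_data @ x_data
y_blk = y_glob[local_block_split(y_shape, rank, comm)]
if y_dist.local_array.size == y_blk.size:
    print(f"Process {rank}: forward block matches NumPy: "
          f"{np.allclose(y_dist.local_array.reshape(y_blk.shape), y_blk)}")
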
###############################################################################
# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
# (which effectively implements a distributed SUMMA matrix-matrix multiplication
# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
# :math:`\mathbf{X}` in a block-block fashion.
xadj_dist = Aop.H @ y_dist

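###############################################################################
# The adjoint result can be checked in the same spirit: each rank compares its
# block of :math:`\mathbf{X}_{adj}` against the NumPy computation
# :math:`\mathbf{A}^H (\mathbf{A} \mathbf{X})`. Again, this is a sketch under
# the assumption that the blocks of :math:`\mathbf{X}_{adj}` are split in the
# same way as those of :math:`\mathbf{X}`.
xadj_glob = A_data.conj().T @ (A_data @ x_data)
xadj_blk = xadj_glob[x_slice]
if xadj_dist.local_array.size == xadj_blk.size:
    print(f"Process {rank}: adjoint block matches NumPy: "
          f"{np.allclose(xadj_dist.local_array.reshape(xadj_blk.shape), xadj_blk)}")
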
###############################################################################
# Finally, we show that the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator can be combined with any other PyLops-MPI operator. Here, we apply a
# conjugate operator to the output of the matrix multiplication.
Dop = Conj(dims=(A_local.shape[0], x_local.shape[1]))
DBop = pylops_mpi.MPIBlockDiag(ops=[Dop, ])
Op = DBop @ Aop
y1 = Op @ x_dist
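
###############################################################################
# Since ``Conj`` acts element-wise (and leaves the real-valued test matrices
# unchanged), the combined operator should return the conjugate of the forward
# result. A final hedged check, under the same block-splitting assumption as
# above:
y1_blk = np.conj(A_data @ x_data)[local_block_split(y_shape, rank, comm)]
if y1.local_array.size == y1_blk.size:
    print(f"Process {rank}: combined operator matches NumPy: "
          f"{np.allclose(y1.local_array.reshape(y1_blk.shape), y1_blk)}")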