| 1 | +r""" |
| 2 | +Distributed Matrix Multiplication - SUMMA |
| 3 | +========================================= |
| 4 | +This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult` |
| 5 | +operator with ``kind='summa'`` to perform matrix-matrix multiplication between |
| 6 | +a matrix :math:`\mathbf{A}` distributed in 2D blocks across a square process |
| 7 | +grid and matrices :math:`\mathbf{X}` and :math:`\mathbf{Y}` distributed in 2D |
| 8 | +blocks across the same grid. Similarly, the adjoint operation can be performed |
| 9 | +with a matrix :math:`\mathbf{Y}` distributed in the same fashion as matrix |
| 10 | +:math:`\mathbf{X}`. |
| 11 | +
|
| 12 | +Note that whilst the different blocks of matrix :math:`\mathbf{A}` are directly |
| 13 | +stored in the operator on different ranks, the matrices :math:`\mathbf{X}` and |
| 14 | +:math:`\mathbf{Y}` are effectively represented by 1-D :py:class:`pylops_mpi.DistributedArray` |
| 15 | +objects where the different blocks are flattened and stored on different ranks. |
| 16 | +Note that to optimize communications, the ranks are organized in a square grid and |
| 17 | +blocks of :math:`\mathbf{A}` and :math:`\mathbf{X}` are systematically broadcast |
| 18 | +across different ranks during computation - see below for details. |
| 19 | +""" |

import math
import numpy as np
from mpi4py import MPI
from matplotlib import pyplot as plt

import pylops_mpi
from pylops import Conj
from pylops_mpi.basicoperators.MatrixMult import \
    local_block_split, MPIMatrixMult, active_grid_comm

plt.close("all")

###############################################################################
# We set the seed such that all processes create input matrices filled with
# the same random numbers. In practical applications, such matrices will be
# filled with data appropriate to the use case.
np.random.seed(42)

###############################################################################
# We are now ready to create the input matrices for our distributed matrix
# multiplication example. We need to set up:
#
# - Matrix :math:`\mathbf{A}` of size :math:`N \times K` (the left operand)
# - Matrix :math:`\mathbf{X}` of size :math:`K \times M` (the right operand)
# - The result will be :math:`\mathbf{Y} = \mathbf{A} \mathbf{X}` of size
#   :math:`N \times M`
#
# Here we create global test matrices with sequential values for easy
# verification:
#
# - Matrix A: Each element :math:`A_{i,j} = i \cdot K + j` (row-major ordering)
# - Matrix X: Each element :math:`X_{i,j} = i \cdot M + j`

N, M, K = 6, 6, 6
A_shape, x_shape, y_shape = (N, K), (K, M), (N, M)

A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape)

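###############################################################################
# Before distributing anything, it is worth recalling the block formulation
# that SUMMA parallelizes: the :math:`(i,j)` block of
# :math:`\mathbf{Y} = \mathbf{A}\mathbf{X}` is the sum over :math:`k` of the
# product of the :math:`(i,k)` block of :math:`\mathbf{A}` with the
# :math:`(k,j)` block of :math:`\mathbf{X}`. The following purely local NumPy
# sketch (not part of the distributed workflow; the block size ``nb`` is chosen
# only for illustration) verifies this identity on the global test matrices.
nb = 3  # illustrative block size, giving a 2 x 2 block partition of the 6 x 6 matrices
Y_blocked = np.zeros((N, M))
for i in range(0, N, nb):
    for j in range(0, M, nb):
        for k in range(0, K, nb):
            # accumulate the (i, j) block as a sum of block-block products
            Y_blocked[i:i + nb, j:j + nb] += \
                A_data[i:i + nb, k:k + nb] @ x_data[k:k + nb, j:j + nb]
assert np.allclose(Y_blocked, A_data @ x_data)
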
###############################################################################
# For distributed computation, we arrange the processes in a square grid of
# size :math:`P' \times P'` where :math:`P' = \lfloor \sqrt{P} \rfloor` and
# :math:`P` is the total number of MPI processes. Each process will own a block
# of each matrix according to this 2D grid layout.

base_comm = MPI.COMM_WORLD
comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
if not is_active:
    # ranks that do not fit in the square grid take no further part
    exit(0)

p_prime = math.isqrt(comm.Get_size())
print(f"Process grid: {p_prime} x {p_prime} = {comm.Get_size()} processes")

if rank == 0:
    print(f"Global matrix A shape: {A_shape} (N={A_shape[0]}, K={A_shape[1]})")
    print(f"Global matrix X shape: {x_shape} (K={x_shape[0]}, M={x_shape[1]})")
    print(f"Expected global result Y shape: ({A_shape[0]}, {x_shape[1]}) = (N, M)")

###############################################################################
# Next we must determine which block of each matrix each process should own.
#
# The 2D block distribution requires:
#
# - Process at grid position :math:`(i,j)` gets block
#   :math:`\mathbf{A}[i_{start}:i_{end}, j_{start}:j_{end}]`
# - Block sizes are approximately :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil`,
#   with edge processes handling the remainder
#
# .. raw:: html
#
#    <div style="text-align: left; font-family: monospace; white-space: pre;">
#    <b>Example: 2x2 Process Grid with 6x6 Matrices</b>
#
#    Matrix A (6x6):              Matrix X (6x6):
#    ┌───────────┬───────────┐    ┌───────────┬───────────┐
#    │  0  1  2  │  3  4  5  │    │  0  1  2  │  3  4  5  │
#    │  6  7  8  │  9 10 11  │    │  6  7  8  │  9 10 11  │
#    │ 12 13 14  │ 15 16 17  │    │ 12 13 14  │ 15 16 17  │
#    ├───────────┼───────────┤    ├───────────┼───────────┤
#    │ 18 19 20  │ 21 22 23  │    │ 18 19 20  │ 21 22 23  │
#    │ 24 25 26  │ 27 28 29  │    │ 24 25 26  │ 27 28 29  │
#    │ 30 31 32  │ 33 34 35  │    │ 30 31 32  │ 33 34 35  │
#    └───────────┴───────────┘    └───────────┴───────────┘
#
#    Process (0,0): A[0:3, 0:3], X[0:3, 0:3]
#    Process (0,1): A[0:3, 3:6], X[0:3, 3:6]
#    Process (1,0): A[3:6, 0:3], X[3:6, 0:3]
#    Process (1,1): A[3:6, 3:6], X[3:6, 3:6]
#    </div>
#

A_slice = local_block_split(A_shape, rank, comm)
x_slice = local_block_split(x_shape, rank, comm)

###############################################################################
# Extract the local portion of each matrix for this process
A_local = A_data[A_slice]
x_local = x_data[x_slice]

print(f"Process {rank}: A_local shape {A_local.shape}, X_local shape {x_local.shape}")
print(f"Process {rank}: A slice {A_slice}, X slice {x_slice}")

###############################################################################
# We are now ready to create the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator and the input matrix :math:`\mathbf{X}`.

Aop = MPIMatrixMult(A_local, M, base_comm=comm, kind="summa", dtype=A_local.dtype)

x_dist = pylops_mpi.DistributedArray(
    global_shape=(K * M),
    local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]),
    base_comm=comm,
    partition=pylops_mpi.Partition.SCATTER,
    dtype=x_local.dtype)
x_dist[:] = x_local.flatten()

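###############################################################################
# As noted in the introduction, the 1-D :py:class:`pylops_mpi.DistributedArray`
# simply stores the flattened local block of :math:`\mathbf{X}` on each rank.
# A quick, purely illustrative check of this correspondence:
assert np.allclose(x_dist.local_array, x_local.ravel())
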
###############################################################################
# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
# effectively implements a distributed matrix-matrix multiplication
# :math:`\mathbf{Y} = \mathbf{AX}`). Note that :math:`\mathbf{Y}` is distributed
# in the same way as the input :math:`\mathbf{X}` in a block-block fashion.
y_dist = Aop @ x_dist
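
###############################################################################
# Since the global matrices are small and replicated on every rank, we can
# compare the distributed result with a dense NumPy reference. This check is
# only an illustrative sketch: it assumes that the output :math:`\mathbf{Y}`
# follows the same block layout returned by ``local_block_split`` for an
# :math:`N \times M` matrix, consistent with the block-block distribution
# described above.
y_ref = A_data @ x_data
y_slice = local_block_split(y_shape, rank, comm)
print(f"Process {rank}: forward block matches NumPy reference: "
      f"{np.allclose(y_dist.local_array, y_ref[y_slice].ravel())}")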

###############################################################################
# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
# (which effectively implements a distributed SUMMA matrix-matrix multiplication
# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
# :math:`\mathbf{X}` in a block-block fashion.
xadj_dist = Aop.H @ y_dist
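
###############################################################################
# The adjoint result can be checked in the same illustrative way against
# :math:`\mathbf{A}^H (\mathbf{A}\mathbf{X})` computed locally with NumPy,
# reusing the block layout of :math:`\mathbf{X}` obtained earlier.
xadj_ref = A_data.conj().T @ (A_data @ x_data)
print(f"Process {rank}: adjoint block matches NumPy reference: "
      f"{np.allclose(xadj_dist.local_array, xadj_ref[x_slice].ravel())}")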

###############################################################################
# Finally, we show that the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator can be combined with any other PyLops-MPI operator. Here we apply a
# complex conjugation operator to the output of the matrix multiplication.
Dop = Conj(dims=(A_local.shape[0], x_local.shape[1]))
DBop = pylops_mpi.MPIBlockDiag(ops=[Dop, ])
Op = DBop @ Aop
y1 = Op @ x_dist
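
###############################################################################
# Since :py:class:`pylops.Conj` acts element-wise on each rank's local block,
# the output of the chained operator should be the complex conjugate of the
# previous forward result (and, as the test data here is real, identical to
# it). A minimal sanity check on the local arrays:
print(f"Process {rank}: chained result equals conj of forward result: "
      f"{np.allclose(y1.local_array, np.conj(y_dist.local_array))}")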