1 | 1 | r"""
2 | | -Distributed SUMMA Matrix Multiplication
3 | | -=======================================
4 | | -This example shows how to use the :py:class:`pylops_mpi.basicoperators._MPISummaMatrixMult`
5 | | -operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}`
6 | | -distributed in 2D blocks across a square process grid and matrices :math:`\mathbf{X}`
7 | | -and :math:`\mathbf{Y}` distributed in 2D blocks across the same grid. Similarly,
8 | | -the adjoint operation can be performed with a matrix :math:`\mathbf{Y}` distributed
9 | | -in the same fashion as matrix :math:`\mathbf{X}`.
| 2 | +Distributed Matrix Multiplication - SUMMA
| 3 | +=========================================
| 4 | +This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
| 5 | +operator with ``kind='summa'`` to perform matrix-matrix multiplication between
| 6 | +a matrix :math:`\mathbf{A}` distributed in 2D blocks across a square process
| 7 | +grid and matrices :math:`\mathbf{X}` and :math:`\mathbf{Y}` distributed in 2D
| 8 | +blocks across the same grid. Similarly, the adjoint operation can be performed
| 9 | +with a matrix :math:`\mathbf{Y}` distributed in the same fashion as matrix
| 10 | +:math:`\mathbf{X}`.
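
In matrix form, writing :math:`\mathbf{X}_{adj}` for the result of the adjoint
(notation introduced here only for illustration), the two operations amount to

.. math::
    \mathbf{Y} = \mathbf{A} \mathbf{X}, \qquad
    \mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y},

where :math:`\mathbf{A}` is of size :math:`N \times K`, :math:`\mathbf{X}` of
size :math:`K \times M`, and :math:`\mathbf{Y}` of size :math:`N \times M`.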
10 | 11 |
11 | 12 | Note that whilst the different blocks of matrix :math:`\mathbf{A}` are directly |
12 | 13 | stored in the operator on different ranks, the matrices :math:`\mathbf{X}` and |
24 | 25 |
25 | 26 | import pylops_mpi |
26 | 27 | from pylops import Conj |
27 | | -from pylops_mpi.basicoperators.MatrixMult import (local_block_spit, MPIMatrixMult, active_grid_comm) |
| 28 | +from pylops_mpi.basicoperators.MatrixMult import \ |
| 29 | + local_block_split, MPIMatrixMult, active_grid_comm |
28 | 30 |
29 | 31 | plt.close("all") |
30 | 32 |
31 | 33 | ############################################################################### |
32 | 34 | # We set the seed such that all processes can create the input matrices filled |
33 | 35 | # with the same random numbers. In practical applications, such matrices will be
34 | | -# filled with data that is appropriate that is appropriate the use-case. |
| 36 | +# filled with data that is appropriate to the use-case. |
35 | 37 | np.random.seed(42) |
36 | 38 |
37 | | - |
38 | | -N, M, K = 6, 6, 6 |
39 | | -A_shape, x_shape, y_shape= (N, K), (K, M), (N, M) |
40 | | - |
41 | | - |
42 | | -base_comm = MPI.COMM_WORLD |
43 | | -comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M) |
44 | | -print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}") |
45 | | - |
46 | | - |
47 | 39 | ############################################################################### |
48 | 40 | # We are now ready to create the input matrices for our distributed matrix |
49 | 41 | # multiplication example. We need to set up: |
| 42 | +# |
50 | 43 | # - Matrix :math:`\mathbf{A}` of size :math:`N \times K` (the left operand) |
51 | 44 | # - Matrix :math:`\mathbf{X}` of size :math:`K \times M` (the right operand) |
52 | | -# - The result will be :math:`\mathbf{Y} = \mathbf{A} \mathbf{X}` of size :math:`N \times M` |
| 45 | +# - The result will be :math:`\mathbf{Y} = \mathbf{A} \mathbf{X}` of size |
| 46 | +# :math:`N \times M` |
| 47 | +# |
| 48 | +# Here we create global test matrices with sequential values for easy verification:
53 | 49 | # |
| 50 | +# - Matrix A: Each element :math:`A_{i,j} = i \cdot K + j` (row-major ordering) |
| 51 | +# - Matrix X: Each element :math:`X_{i,j} = i \cdot M + j` |
| 52 | + |
| 53 | +N, M, K = 6, 6, 6 |
| 54 | +A_shape, x_shape, y_shape = (N, K), (K, M), (N, M) |
| 55 | + |
| 56 | +A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape) |
| 57 | +x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape) |
| 58 | + |
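# A small worked check of the element formulas above (for illustration only):
# with row-major ordering, A_{1,2} = 1 * K + 2 = 8 and X_{2,3} = 2 * M + 3 = 15.
assert A_data[1, 2] == 1 * K + 2
assert x_data[2, 3] == 2 * M + 3
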
| 59 | +################################################################################ |
54 | 60 | # For distributed computation, we arrange processes in a square grid of size |
55 | 61 | # :math:`P' \times P'` where :math:`P' = \sqrt{P}` and :math:`P` is the total |
56 | 62 | # number of MPI processes. Each process will own a block of each matrix |
57 | 63 | # according to this 2D grid layout. |
58 | 64 |
| 65 | +base_comm = MPI.COMM_WORLD |
| 66 | +comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M) |
| 67 | +print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}") |
| 68 | + |
59 | 69 | p_prime = math.isqrt(comm.Get_size()) |
60 | 70 | print(f"Process grid: {p_prime} x {p_prime} = {comm.Get_size()} processes") |
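# As an illustrative sketch (assuming, purely for illustration, a row-major
# rank-to-grid mapping), rank ``r`` would sit at grid position
# ``(r // p_prime, r % p_prime)``; the authoritative coordinates are the
# ``row_id`` and ``col_id`` values returned by ``active_grid_comm`` above.
row_guess, col_guess = divmod(rank, p_prime)
print(f"Process {rank}: assumed grid position ({row_guess}, {col_guess}), "
      f"reported (row={row_id}, col={col_id})")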
61 | 71 |
62 | | -# Create global test matrices with sequential values for easy verification |
63 | | -# Matrix A: Each element :math:`A_{i,j} = i \cdot K + j` (row-major ordering) |
64 | | -# Matrix X: Each element :math:`X_{i,j} = i \cdot M + j` |
65 | | -A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape) |
66 | | -x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape) |
67 | | - |
68 | | -print(f"Global matrix A shape: {A_shape} (N={A_shape[0]}, K={A_shape[1]})") |
69 | | -print(f"Global matrix X shape: {x_shape} (K={x_shape[0]}, M={x_shape[1]})") |
70 | | -print(f"Expected Global result Y shape: ({A_shape[0]}, {x_shape[1]}) = (N, M)") |
| 72 | +if rank == 0: |
| 73 | + print(f"Global matrix A shape: {A_shape} (N={A_shape[0]}, K={A_shape[1]})") |
| 74 | + print(f"Global matrix X shape: {x_shape} (K={x_shape[0]}, M={x_shape[1]})") |
| 75 | + print(f"Expected Global result Y shape: ({A_shape[0]}, {x_shape[1]}) = (N, M)") |
71 | 76 |
72 | 77 | ################################################################################ |
73 | | -# Determine which block of each matrix this process should own |
74 | | -# The 2D block distribution ensures: |
75 | | -# - Process at grid position :math:`(i,j)` gets block :math:`\mathbf{A}[i_{start}:i_{end}, j_{start}:j_{end}]` |
76 | | -# - Block sizes are approximately :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil` with edge processes handling remainder |
| 78 | +# Next we must determine which block of each matrix each process should own. |
| 79 | +# |
| 80 | +# The 2D block distribution ensures that:
| 81 | +# |
| 82 | +# - Process at grid position :math:`(i,j)` gets block |
| 83 | +# :math:`\mathbf{A}[i_{start}:i_{end}, j_{start}:j_{end}]` |
| 84 | +# - Block sizes are approximately :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil` |
| 85 | +#   with edge processes handling the remainder (see the sanity check after the split below)
77 | 86 | # |
78 | 87 | # .. raw:: html |
79 | 88 | # |
98 | 107 | # </div> |
99 | 108 | # |
100 | 109 |
101 | | -A_slice = local_block_spit(A_shape, rank, comm) |
102 | | -x_slice = local_block_spit(x_shape, rank, comm) |
| 110 | +A_slice = local_block_split(A_shape, rank, comm) |
| 111 | +x_slice = local_block_split(x_shape, rank, comm) |
| 112 | + |
103 | 113 | ################################################################################ |
104 | 114 | # Extract the local portion of each matrix for this process |
105 | 115 | A_local = A_data[A_slice] |
108 | 118 | print(f"Process {rank}: A_local shape {A_local.shape}, X_local shape {x_local.shape}") |
109 | 119 | print(f"Process {rank}: A slice {A_slice}, X slice {x_slice}") |
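
################################################################################
# As a quick sanity check (a minimal sketch based on the ceiling-division block
# sizes described above), no local block of :math:`\mathbf{A}` should exceed
# :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil`. For :math:`N = K = 6` on
# a :math:`2 \times 2` grid this means 3 x 3 blocks, e.g. the process at grid
# position :math:`(0, 1)` owns ``A[0:3, 3:6]``.
max_block_rows = math.ceil(N / p_prime)
max_block_cols = math.ceil(K / p_prime)
assert A_local.shape[0] <= max_block_rows
assert A_local.shape[1] <= max_block_cols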
110 | 120 |
111 | | -x_dist = pylops_mpi.DistributedArray(global_shape=(K * M), |
112 | | - local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]), |
113 | | - base_comm=comm, |
114 | | - partition=pylops_mpi.Partition.SCATTER, |
115 | | - dtype=x_local.dtype) |
116 | | -x_dist[:] = x_local.flatten() |
| 121 | +################################################################################ |
117 | 122 |
118 | 123 | ################################################################################ |
119 | 124 | # We are now ready to create the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult` |
120 | | -# operator and the input matrix :math:`\mathbf{X}`. Given that we chose a block-block distribution |
121 | | -# of data we shall use SUMMA |
| 125 | +# operator and the distributed input matrix :math:`\mathbf{X}`.
| 126 | + |
122 | 127 | Aop = MPIMatrixMult(A_local, M, base_comm=comm, kind="summa", dtype=A_local.dtype) |
123 | 128 |
| 129 | +x_dist = pylops_mpi.DistributedArray( |
| 130 | + global_shape=(K * M), |
| 131 | + local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]), |
| 132 | + base_comm=comm, |
| 133 | + partition=pylops_mpi.Partition.SCATTER, |
| 134 | + dtype=x_local.dtype) |
| 135 | +x_dist[:] = x_local.flatten() |
| 136 | + |
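################################################################################
# As a small check (a sketch that assumes the ``local_array`` attribute of
# :py:class:`pylops_mpi.DistributedArray`, which exposes the rank-local chunk),
# the local portion of ``x_dist`` should match the flattened local block of
# :math:`\mathbf{X}` owned by this process.
assert np.allclose(x_dist.local_array, x_local.ravel())
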
124 | 137 | ################################################################################ |
125 | 138 | # We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which |
126 | 139 | # effectively implements a distributed matrix-matrix multiplication |