Commit ae5661b

minor: small improvements to text
1 parent b7e6702 commit ae5661b

File tree

3 files changed: +56 -55 lines changed

examples/plot_matrixmult.py

Lines changed: 18 additions & 17 deletions
@@ -7,7 +7,7 @@
 matrix :math:`\mathbf{X}` blocked over columns (i.e., blocks of columns are
 stored over different ranks), with equal number of row and column blocks.
 Similarly, the adjoint operation can be peformed with a matrix :math:`\mathbf{Y}`
-blocked in the same fashion of matrix :math:`\mathbf{X}.
+blocked in the same fashion of matrix :math:`\mathbf{X}`.

 Note that whilst the different blocks of the matrix :math:`\mathbf{A}` are directly
 stored in the operator on different ranks, the matrix :math:`\mathbf{X}` is
@@ -18,6 +18,7 @@
 replicated across different ranks - see below for details.

 """
+
 from matplotlib import pyplot as plt
 import math
 import numpy as np
@@ -41,7 +42,7 @@
 rank = comm.Get_rank()  # rank of current process
 size = comm.Get_size()  # number of processes

-p_prime = math.isqrt(size)
+p_prime = math.isqrt(size)
 repl_factor = p_prime

 if (p_prime * repl_factor) != size:
@@ -78,19 +79,19 @@
 # \quad
 # c = \mathrm{rank} \bmod P'.
 #
-#For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
+# For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
 #
-#.. raw:: html
+# .. raw:: html
 #
-# <div style="text-align: center; font-family: monospace; white-space: pre;">
-# ┌────────────┬────────────┐
-# │   Rank 0   │   Rank 1   │
-# │ (r=0, c=0) │ (r=0, c=1) │
-# ├────────────┼────────────┤
-# │   Rank 2   │   Rank 3   │
-# │ (r=1, c=0) │ (r=1, c=1) │
-# └────────────┴────────────┘
-# </div>
+# <div style="text-align: center; font-family: monospace; white-space: pre;">
+# ┌────────────┬────────────┐
+# │   Rank 0   │   Rank 1   │
+# │ (r=0, c=0) │ (r=0, c=1) │
+# ├────────────┼────────────┤
+# │   Rank 2   │   Rank 3   │
+# │ (r=1, c=0) │ (r=1, c=1) │
+# └────────────┴────────────┘
+# </div>

 my_col = rank % p_prime
 my_row = rank // p_prime
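As a sanity check of the mapping in this hunk, a minimal sketch (plain Python, no MPI; the process count size = 4 is assumed so as to match the 2×2 layout above):

import math

size = 4                      # assumed process count; must be a perfect square
p_prime = math.isqrt(size)    # side of the square process grid, P'

for rank in range(size):
    my_row = rank // p_prime  # r = floor(rank / P')
    my_col = rank % p_prime   # c = rank mod P'
    print(f"Rank {rank} -> (r={my_row}, c={my_col})")

# Rank 0 -> (r=0, c=0)
# Rank 1 -> (r=0, c=1)
# Rank 2 -> (r=1, c=0)
# Rank 3 -> (r=1, c=1)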
@@ -111,10 +112,10 @@
 # <div style="text-align: left; font-family: monospace; white-space: pre;">
 # <b>Matrix A (4 x 4):</b>
 # ┌─────────────────┐
-# │ a11 a12 a13 a14 │ <- Rows 0–1 (Process Grid Col 0)
+# │ a11 a12 a13 a14 │ <- Rows 0–1 (Process Grid Row 0)
 # │ a21 a22 a23 a24 │
 # ├─────────────────┤
-# │ a41 a42 a43 a44 │ <- Rows 2–3 (Process Grid Col 1)
+# │ a41 a42 a43 a44 │ <- Rows 2–3 (Process Grid Row 1)
 # │ a51 a52 a53 a54 │
 # └─────────────────┘
 # </div>
@@ -124,7 +125,7 @@
 # <div style="text-align: left; font-family: monospace; white-space: pre;">
 # <b>Matrix X (4 x 4):</b>
 # ┌─────────┬─────────┐
-# │ b11 b12 │ b13 b14 │ <- Cols 0–1 (Process Grid Row 0), Cols 2–3 (Process Grid Row 1)
+# │ b11 b12 │ b13 b14 │ <- Cols 0–1 (Process Grid Col 0), Cols 2–3 (Process Grid Col 1)
 # │ b21 b22 │ b23 b24 │
 # │ b31 b32 │ b33 b34 │
 # │ b41 b42 │ b43 b44 │
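The two layouts above can be reproduced with plain NumPy splits; a short sketch (the values are placeholders, only the shapes and the split axes matter):

import numpy as np

A = np.arange(16, dtype="float32").reshape(4, 4)
X = np.arange(16, dtype="float32").reshape(4, 4)

A_row_blocks = np.split(A, 2, axis=0)  # grid row r holds A_row_blocks[r], shape (2, 4)
X_col_blocks = np.split(X, 2, axis=1)  # grid col c holds X_col_blocks[c], shape (4, 2)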
@@ -147,7 +148,7 @@

 ################################################################################
 # We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
-# operator and the input matrix math:`\mathbf{X}`
+# operator and the input matrix :math:`\mathbf{X}`
 Aop = MPIMatrixMult(A_p, M, dtype="float32")

 col_lens = comm.allgather(my_own_cols)
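Once created, Aop is applied like any other PyLops-MPI operator; a sketch mirroring the test file further down (here x_dist is assumed to be a DistributedArray holding this rank's block-column of X):

y_dist = Aop @ x_dist        # forward: Y = A @ X, one block-column per rank
xadj_dist = Aop.H @ y_dist   # adjoint: Xadj = A^H @ Y, one block-column per rank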

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 20 additions & 25 deletions
@@ -23,7 +23,7 @@ class MPIMatrixMult(MPILinearOperator):
 ----------
 A : :obj:`numpy.ndarray`
     Local block of the matrix of shape :math:`[N_{loc} \times K]`
-    where ``N_loc`` is the number of rows stored on this MPI rank and
+    where :math:`N_{loc}` is the number of rows stored on this MPI rank and
     ``K`` is the global number of columns.
 M : :obj:`int`
     Global leading dimension (i.e., number of columns) of the matrices
@@ -46,7 +46,7 @@ class MPIMatrixMult(MPILinearOperator):
 Raises
 ------
 Exception
-    If the operator is created without a square number of mpi ranks.
+    If the operator is created with a non-square number of MPI ranks.
 ValueError
     If input vector does not have the correct partition type.

@@ -64,15 +64,15 @@ class MPIMatrixMult(MPILinearOperator):
 :math:`\mathbf{A}^H` is the complex conjugate and transpose of :math:`\mathbf{A}`.

 This implementation is based on a 1D block distribution of the operator
-matrix and reshaped model and data vectors replicated across math:`P`
+matrix and reshaped model and data vectors replicated across :math:`P`
 processes by a factor equivalent to :math:`\sqrt{P}` across a square process
 grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically:

 - The matrix ``A`` is distributed across MPI processes in a block-row fashion
   and each process holds a local block of ``A`` with shape
   :math:`[N_{loc} \times K]`
 - The operand matrix ``X`` is distributed in a block-column fashion and
-  each process holds a local block of ``X`` with shape
+  each process holds a local block of ``X`` with shape
   :math:`[K \times M_{loc}]`
 - Communication is minimized by using a 2D process grid layout

@@ -82,17 +82,13 @@ class MPIMatrixMult(MPILinearOperator):
 of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
 is the number of columns assigned to the current process.

-2. **Data Broadcasting**: Within each row (processes with same ``row_id``),
-   the operand data is broadcast from the process whose ``col_id`` matches
-   the ``row_id`` (processes along the diagonal). This ensures all processes
-   in a row have access to the same operand columns.
-
-3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
+2. **Local Computation**: Each process computes ``A_local @ X_local`` where:
    - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
    - ``X_local`` is the broadcasted operand (shape ``K x M_local``)

-4. **Row-wise Gather**: Results from all processes in each row are gathered
-   using ``allgather`` to reconstruct the full result matrix vertically.
+3. **Row-wise Gather**: Results from all processes in each row are gathered
+   using ``allgather`` to ensure that each rank has a block-column of the
+   output matrix.

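Taken together, the forward steps amount to a block product that can be checked against a dense multiply; a NumPy-only sketch (no MPI, hypothetical 2×2 grid; the allgather is emulated by stacking the per-row partial products):

import numpy as np

N, K, M, p = 4, 4, 4, 2                  # global sizes; p = sqrt(P)
A = np.arange(N * K, dtype=float).reshape(N, K)
X = np.arange(K * M, dtype=float).reshape(K, M)

A_blocks = np.split(A, p, axis=0)        # block-rows of A, one per grid row
X_blocks = np.split(X, p, axis=1)        # block-columns of X, one per grid column

# step 2: every (r, c) process forms its local product, shape (N/p, M/p)
partial = [[A_blocks[r] @ X_blocks[c] for c in range(p)] for r in range(p)]

# step 3: stacking the partial results vertically leaves each rank with a
# block-column of the output
Y_cols = [np.vstack([partial[r][c] for r in range(p)]) for c in range(p)]
assert np.allclose(np.hstack(Y_cols), A @ X)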
 **Adjoint Operation step-by-step**

@@ -101,21 +97,20 @@ class MPIMatrixMult(MPILinearOperator):
 1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N, M_local)``
    representing the local columns of the input matrix.

-2. **Local Adjoint Computation**:
-   Each process computes ``A_local.H @ X_tile``
-   where ``A_local.H`` is either:
-   - Pre-computed ``At`` (if ``saveAt=True``)
-   - Computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``)
-   Each process multiplies its transposed local ``A`` block ``A_local^H``
-   (shape ``K x N_block``)
-   with the extracted ``X_tile`` (shape ``N_block x M_local``),
-   producing a partial result of shape ``(K, M_local)``.
-   This computes the local contribution of columns of ``A^H`` to the final result.
+2. **Local Adjoint Computation**: Each process computes
+   ``A_local.H @ X_tile`` where ``A_local.H`` is either i) pre-computed
+   and stored in ``At`` (if ``saveAt=True``), or ii) computed on-the-fly as
+   ``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its
+   transposed local ``A`` block ``A_local^H`` (shape ``K x N_block``)
+   with the extracted ``X_tile`` (shape ``N_block x M_local``),
+   producing a partial result of shape ``(K, M_local)``.
+   This computes the local contribution of columns of ``A^H`` to the final
+   result.

 3. **Row-wise Reduction**: Since the full result ``Y = A^H \cdot X`` is the
-   sum of contributions from all column blocks of ``A^H``, processes in the
-   same rows perform an ``allreduce`` sum to combine their partial results.
-   This gives the complete ``(K, M_local)`` result for their assigned columns.
+   sum of the contributions from all column blocks of ``A^H``, processes in
+   the same row perform an ``allreduce`` sum to combine their partial results.
+   This gives the complete ``(K, M_local)`` result for their assigned column.

 """
 def __init__(
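A companion NumPy-only sketch of the adjoint steps above (same hypothetical 2×2 grid; the allreduce is emulated by the explicit sum over the block-row index):

import numpy as np

N, K, M, p = 4, 6, 4, 2
rng = np.random.default_rng(0)
A = rng.standard_normal((N, K))
Y = rng.standard_normal((N, M))

A_blocks = np.split(A, p, axis=0)        # block-rows of A, shape (N/p, K) each
Y_tiles = [np.split(Yc, p, axis=0) for Yc in np.split(Y, p, axis=1)]

# step 2: each partial product A_local^H @ Y_tile has shape (K, M/p);
# step 3: summing over the block-row index plays the role of the allreduce
Xadj_cols = [sum(A_blocks[r].conj().T @ Y_tiles[c][r] for r in range(p))
             for c in range(p)]
assert np.allclose(np.hstack(Xadj_cols), A.conj().T @ Y)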

tests/test_matrixmult.py

Lines changed: 18 additions & 13 deletions
@@ -1,39 +1,42 @@
-import pytest
+"""Test the MPIMatrixMult class
+Designed to run with n processes
+$ mpiexec -n 10 pytest test_matrixmult.py --with-mpi
+"""
+import math
 import numpy as np
 from numpy.testing import assert_allclose
 from mpi4py import MPI
-import math
-import sys
+import pytest

 from pylops_mpi import DistributedArray, Partition
 from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

 np.random.seed(42)
-
 comm = MPI.COMM_WORLD
 rank = comm.Get_rank()
 size = comm.Get_size()

-# Define test cases: (N K, M, dtype_str)
+# Define test cases: (N, K, M, dtype_str)
 # M, K, N are matrix dimensions A(N,K), B(K,M)
 # P_prime will be ceil(sqrt(size)).
 test_params = [
-    pytest.param(37, 37, 37, "float32", id="f32_37_37_37"),
-    pytest.param(50, 30, 40, "float64", id="f64_50_30_40"),
+    pytest.param(37, 37, 37, "float32", id="f32_37_37_37"),
+    pytest.param(50, 30, 40, "float64", id="f64_50_30_40"),
     pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"),
-    pytest.param( 3,  4,  5, "float32", id="f32_3_4_5"),
-    pytest.param( 1,  2,  1, "float64", id="f64_1_2_1",),
-    pytest.param( 2,  1,  3, "float32", id="f32_2_1_3",),
+    pytest.param( 3,  4,  5, "float32", id="f32_3_4_5"),
+    pytest.param( 1,  2,  1, "float64", id="f64_1_2_1",),
+    pytest.param( 2,  1,  3, "float32", id="f32_2_1_3",),
 ]


-@pytest.mark.mpi(min_size=1)  # SUMMA should also work for 1 process.
+@pytest.mark.mpi(min_size=1)
 @pytest.mark.parametrize("M, K, N, dtype_str", test_params)
 def test_SUMMAMatrixMult(N, K, M, dtype_str):
     p_prime = math.isqrt(size)
     C = p_prime
     if p_prime * C != size:
-        pytest.skip(f"Number of processes must be a square number, provided {size} instead...")
+        pytest.skip(f"Number of processes must be a square number, "
+                    f"provided {size} instead...")

     dtype = np.dtype(dtype_str)


@@ -86,11 +89,13 @@ def test_SUMMAMatrixMult(N, K, M, dtype_str):

     x_dist.local_array[:] = X_p.ravel()

-    # Forward operation: y = A @ B (distributed)
+    # Forward operation: y = A @ x (distributed)
     y_dist = Aop @ x_dist
+
     # Adjoint operation: xadj = A.H @ y (distributed)
     xadj_dist = Aop.H @ y_dist

+    # Re-organize in local matrix
     y = y_dist.asarray(masked=True)
     col_counts = [min(blk_cols_X, M - j * blk_cols_X) for j in range(p_prime)]
     y_blocks = []