Commit 42452a1

Initial docstring for matrix mult
1 parent 8a56096 commit 42452a1

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 80 additions & 4 deletions
@@ -19,19 +19,95 @@ class MPIMatrixMult(MPILinearOperator):
Parameters
----------
A : :obj:`numpy.ndarray`
-    Matrix multiplication operator of size
-    :math:`[ \times ]`
+    Local block of the matrix multiplication operator of shape ``(M_loc, K)``,
+    where ``M_loc`` is the number of rows stored on this MPI rank and
+    ``K`` is the global number of columns.
+N : :obj:`int`
+    Global leading dimension of the operand matrix (number of columns).
saveAt : :obj:`bool`, optional
    Save ``A`` and ``A.H`` to speed up the computation of adjoint
    (``True``) or create ``A.H`` on-the-fly (``False``)
-    Note that ``saveAt=True`` will double the amount of required memory
+    Note that ``saveAt=True`` will double the amount of required memory.
+    The default is ``False``.
base_comm : :obj:`mpi4py.MPI.Comm`, optional
    MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
dtype : :obj:`str`, optional
    Type of elements in input array.

+Attributes
+----------
+shape : :obj:`tuple`
+    Operator shape
+
+Raises
+------
+Exception
+    If the operator is created without a perfect-square number of MPI ranks.
+ValueError
+    If the input vector does not have the correct partition type.
+
Notes
-----
+This implementation uses a 1D block distribution of the operand matrix, with
+the operator replicated across the :math:`P` processes by a factor of
+:math:`\sqrt{P}`, over a square process grid (:math:`\sqrt{P}\times\sqrt{P}`).
+
+The operator implements a distributed matrix-matrix multiplication where:
+
+- The matrix ``A`` is distributed across MPI processes in a block-row fashion
+- Each process holds a local block of ``A`` with shape ``(M_loc, K)``
+- The operand matrix ``X`` is distributed in a block-column fashion
+- Communication is minimized by using a 2D process grid layout (see the
+  sketch just below)
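A minimal mpi4py sketch of this grid bookkeeping, where the row-major mapping of ranks to ``layer_id``/``group_id`` is an assumption for illustration, while the two ``Split`` calls mirror those shown in ``__init__`` further down in this diff:

    import math
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    P = comm.Get_size()
    P_prime = math.isqrt(P)          # side of the square process grid
    assert P_prime * P_prime == P, "requires a perfect-square number of ranks"

    rank = comm.Get_rank()
    layer_id = rank // P_prime       # assumed: grid row of this rank
    group_id = rank % P_prime        # assumed: grid column of this rank

    # sub-communicators, mirroring the Split calls in __init__
    layer_comm = comm.Split(color=layer_id, key=group_id)   # same layer (row)
    group_comm = comm.Split(color=group_id, key=layer_id)   # same group (column)

Under this layout, the ranks sharing a ``layer_comm`` collectively cover all rows of ``A``, while each grid column works on its own block of operand columns, as described in the steps below.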
+
+The forward operation computes :math:`Y = A \cdot X` where:
+
+- :math:`A` is the distributed matrix operator of shape ``(M, K)``
+- :math:`X` is the distributed operand matrix of shape ``(K, N)``
+- :math:`Y` is the resulting distributed matrix of shape ``(M, N)``
+
+The adjoint operation computes :math:`Y = A^H \cdot X`, where :math:`A^H`
+is the conjugate transpose of :math:`A`.
+
+Steps for the Forward Operation (:math:`Y = A \cdot X`)
+--------------------------------------------------------
+1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
+   of shape ``(K, N)``) is reshaped to ``(K, N_local)``, where ``N_local``
+   is the number of columns assigned to the current process.
+
+2. **Data Broadcasting**: Within each layer (processes with the same
+   ``layer_id``), the operand data is broadcast from the process whose
+   ``group_id`` matches the ``layer_id``. This ensures that all processes
+   in a layer have access to the same operand columns.
+
+3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
+
+   - ``A_local`` is the local block of matrix ``A`` (shape ``M_local × K``)
+   - ``X_local`` is the broadcasted operand (shape ``K × N_local``)
+
+4. **Layer Gather**: Results from all processes in each layer are gathered
+   using ``allgather`` to reconstruct the full result matrix vertically
+   (see the sketch below).
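A minimal sketch of steps 2-4 with plain mpi4py and NumPy, assuming ``layer_comm`` and ``layer_id`` from the grid sketch above, a local block ``A_local`` of shape ``(M_loc, K)``, an operand block ``x_local`` already reshaped to ``(K, N_local)`` (step 1), and that ranks inside a layer are ordered by ``group_id`` so the broadcast root is simply ``layer_id``:

    import numpy as np

    def forward_sketch(layer_comm, layer_id, A_local, x_local):
        # 2. broadcast the operand block within the layer; only the root's
        #    x_local (the rank whose group_id equals layer_id) is actually sent
        X_block = layer_comm.bcast(x_local, root=layer_id)
        # 3. local computation: (M_loc, K) @ (K, N_local) -> (M_loc, N_local)
        Y_block = A_local @ X_block
        # 4. layer gather: stack the row blocks vertically into the full (M, N_local)
        return np.vstack(layer_comm.allgather(Y_block))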
+
+Steps for the Adjoint Operation (:math:`Y = A^H \cdot X`)
+-----------------------------------------------------------
+The adjoint operation performs the conjugate-transpose multiplication:
+
+1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(M, N_local)``,
+   representing the local columns of the input matrix.
+
+2. **Local Adjoint Computation**: Each process computes ``A_local.H @ X_tile``,
+   where ``A_local.H`` is either:
+
+   - the pre-computed ``At`` (if ``saveAt=True``), or
+   - computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``).
+
+   Each process multiplies its transposed local ``A`` block ``A_local^H``
+   (shape ``K × M_block``) with the extracted ``X_tile`` (shape
+   ``M_block × N_local``), producing a partial result of shape ``(K, N_local)``.
+   This is the local contribution of the corresponding columns of ``A^H`` to
+   the final result.
+
+3. **Layer Reduction**: Since the full result ``Y = A^H \cdot X`` is the sum
+   of the contributions from all column blocks of ``A^H``, processes in the
+   same layer perform an ``allreduce`` sum to combine their partial results.
+   This gives the complete ``(K, N_local)`` result for their assigned columns
+   (see the sketch below).
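The adjoint steps admit the same kind of sketch, assuming each rank knows the half-open row range ``r0:r1`` that its block of ``A`` covers (so ``r1 - r0 == M_block``) and holds the reshaped ``(M, N_local)`` input as ``X_full``; passing the saved ``At`` corresponds to ``saveAt=True``:

    import numpy as np
    from mpi4py import MPI

    def adjoint_sketch(layer_comm, A_local, X_full, r0, r1, At=None):
        # 2. local adjoint computation on the extracted tile:
        #    (K, M_block) @ (M_block, N_local) -> (K, N_local)
        Ah = At if At is not None else A_local.conj().T
        Y_partial = Ah @ X_full[r0:r1, :]
        # 3. layer reduction: sum the partial contributions from all
        #    column blocks of A^H across the layer
        Y_local = np.empty_like(Y_partial)
        layer_comm.Allreduce(Y_partial, Y_local, op=MPI.SUM)
        return Y_local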
"""
def __init__(
    self,

@@ -58,14 +134,14 @@ def __init__(
    self.base_comm = base_comm
    self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
    self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
+
    self.A = A.astype(np.dtype(dtype))
    if saveAt: self.At = A.T.conj()

    self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
    self.K = A.shape[1]
    self.N = N

-    # Determine how many columns each group holds
    block_cols = int(math.ceil(self.N / self._P_prime))
    blk_rows = int(math.ceil(self.M / self._P_prime))
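The ceil-division in the last two lines fixes the block sizes on the grid; a toy illustration with hypothetical sizes (``M = 8``, ``N = 6`` on a ``2 × 2`` grid), where the half-open block ranges are an assumed convention:

    import math

    M, N = 8, 6                # hypothetical global sizes
    P_prime = 2                # 2 x 2 grid, i.e. 4 MPI ranks

    block_cols = int(math.ceil(N / P_prime))   # operand columns per group -> 3
    blk_rows = int(math.ceil(M / P_prime))     # operator rows per block   -> 4

    for i in range(P_prime):
        c0, c1 = i * block_cols, min((i + 1) * block_cols, N)
        r0, r1 = i * blk_rows, min((i + 1) * blk_rows, M)
        print(f"block {i}: X columns [{c0}:{c1}), A rows [{r0}:{r1})")

Each rank would then pass only its local row block and the global ``N`` to the constructor, e.g. ``MPIMatrixMult(A_local, N, saveAt=True)``, per the Parameters section above.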