@@ -23,7 +23,7 @@ class MPIMatrixMult(MPILinearOperator):
     ----------
     A : :obj:`numpy.ndarray`
         Local block of the matrix of shape :math:`[N_{loc} \times K]`
-        where ``N_loc`` is the number of rows stored on this MPI rank and
+        where :math:`N_{loc}` is the number of rows stored on this MPI rank and
         ``K`` is the global number of columns.
     M : :obj:`int`
         Global leading dimension (i.e., number of columns) of the matrices
@@ -46,7 +46,7 @@ class MPIMatrixMult(MPILinearOperator):
     Raises
     ------
     Exception
-        If the operator is created without a square number of mpi ranks.
+        If the operator is created with a non-square number of MPI ranks.
     ValueError
         If input vector does not have the correct partition type.

@@ -64,15 +64,15 @@ class MPIMatrixMult(MPILinearOperator):
     :math:`\mathbf{A}^H` is the complex conjugate and transpose of :math:`\mathbf{A}`.

     This implementation is based on a 1D block distribution of the operator
-    matrix and reshaped model and data vectors replicated across math:`P`
+    matrix and reshaped model and data vectors replicated across :math:`P`
     processes by a factor equivalent to :math:`\sqrt{P}` across a square process
     grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically:

     - The matrix ``A`` is distributed across MPI processes in a block-row fashion
       and each process holds a local block of ``A`` with shape
       :math:`[N_{loc} \times K]`
     - The operand matrix ``X`` is distributed in a block-column fashion and
-      each process holds a local block of ``X`` with shape
+      each process holds a local block of ``X`` with shape
       :math:`[K \times M_{loc}]`
     - Communication is minimized by using a 2D process grid layout

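To make the layout above concrete, here is a minimal mpi4py sketch of the square process grid and the local block shapes. The row-major rank-to-grid mapping, the `comm.Split` row communicator, and the example sizes are assumptions for illustration only, not the operator's internals.

```python
import math

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# The operator requires a square number of ranks: P = sqrt(P) * sqrt(P)
p_prime = math.isqrt(size)
if p_prime * p_prime != size:
    raise Exception("MPIMatrixMult requires a square number of MPI ranks")

# Assumed row-major mapping of ranks onto the sqrt(P) x sqrt(P) grid
row_id, col_id = divmod(rank, p_prime)
row_comm = comm.Split(color=row_id, key=col_id)  # ranks in the same grid row

# Illustrative global sizes, chosen to divide evenly by sqrt(P)
N, K, M = 8, 6, 4
n_loc = N // p_prime  # rows of the local block of A
m_loc = M // p_prime  # columns of the local blocks of X and Y

A_local = np.random.rand(n_loc, K)  # block-row of A held by this rank
X_local = np.random.rand(K, m_loc)  # block-column of X held by this rank
```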
@@ -82,17 +82,13 @@ class MPIMatrixMult(MPILinearOperator):
        of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
        is the number of columns assigned to the current process.

-    2. **Data Broadcasting**: Within each row (processes with same ``row_id``),
-       the operand data is broadcast from the process whose ``col_id`` matches
-       the ``row_id`` (processes along the diagonal). This ensures all processes
-       in a row have access to the same operand columns.
-
-    3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
+    2. **Local Computation**: Each process computes ``A_local @ X_local`` where:
        - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
        - ``X_local`` is the broadcasted operand (shape ``K x M_local``)

-    4. **Row-wise Gather**: Results from all processes in each row are gathered
-       using ``allgather`` to reconstruct the full result matrix vertically.
+    3. **Row-wise Gather**: Results from all processes in each row are gathered
+       using ``allgather`` to ensure that each rank has a block-column of the
+       output matrix.

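Reusing the names from the setup sketch above, the forward steps just described (local product followed by a row-wise `allgather`) reduce to roughly the following pattern; this is only a sketch of the communication described in the docstring, not the library's actual code.

```python
# (continues the setup sketch above: row_comm, A_local, X_local, n_loc, m_loc)
import numpy as np

# Step 2 (Local Computation): block-row of A times block-column of X
Y_partial = A_local @ X_local  # shape (n_loc, m_loc)

# Step 3 (Row-wise Gather): ranks in the same grid row exchange their partial
# results so that each of them ends up with a full block-column of the output
Y_blocks = row_comm.allgather(Y_partial)  # list of sqrt(P) arrays
Y_local = np.vstack(Y_blocks)             # shape (sqrt(P) * n_loc, m_loc)
```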
     **Adjoint Operation step-by-step**

@@ -101,21 +97,20 @@ class MPIMatrixMult(MPILinearOperator):
     1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N, M_local)``
        representing the local columns of the input matrix.

-    2. **Local Adjoint Computation**:
-       Each process computes ``A_local.H @ X_tile``
-       where ``A_local.H`` is either:
-       - Pre-computed ``At`` (if ``saveAt=True``)
-       - Computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``)
-       Each process multiplies its transposed local ``A`` block ``A_local^H``
-       (shape ``K x N_block``)
-       with the extracted ``X_tile`` (shape ``N_block x M_local``),
-       producing a partial result of shape ``(K, M_local)``.
-       This computes the local contribution of columns of ``A^H`` to the final result.
+    2. **Local Adjoint Computation**: Each process computes
+       ``A_local.H @ X_tile`` where ``A_local.H`` is either i) pre-computed
+       and stored in ``At`` (if ``saveAt=True``), or ii) computed on-the-fly as
+       ``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its
+       transposed local ``A`` block ``A_local^H`` (shape ``K x N_block``)
+       with the extracted ``X_tile`` (shape ``N_block x M_local``),
+       producing a partial result of shape ``(K, M_local)``.
+       This computes the local contribution of columns of ``A^H`` to the final
+       result.

     3. **Row-wise Reduction**: Since the full result ``Y = A^H \cdot X`` is the
-       sum of contributions from all column blocks of ``A^H``, processes in the
-       same rows perform an ``allreduce`` sum to combine their partial results.
-       This gives the complete ``(K, M_local)`` result for their assigned columns .
+       sum of the contributions from all column blocks of ``A^H``, processes in
+       the same row perform an ``allreduce`` sum to combine their partial results.
+       This gives the complete ``(K, M_local)`` result for their assigned block-column.

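The adjoint steps map onto a similarly small pattern, again reusing the setup sketch above; `X_tile` and the `saveAt` flag are stand-ins for the quantities named in the docstring, and random data replaces the real extracted slice.

```python
# (continues the sketches above: MPI, np, row_comm, A_local, n_loc, m_loc)
saveAt = True
At = A_local.T.conj() if saveAt else None  # optionally precompute A_local^H

# X_tile: the rows of the reshaped input that line up with this rank's block
# of A; random data stands in for the extracted slice here
X_tile = np.random.rand(n_loc, m_loc)

# Step 2 (Local Adjoint Computation): partial contribution of shape (K, m_loc)
A_localH = At if saveAt else A_local.T.conj()
Y_partial = A_localH @ X_tile

# Step 3 (Row-wise Reduction): sum the partial contributions across the grid
# row so that every rank holds the complete (K, m_loc) block-column of A^H X
Y_local = row_comm.allreduce(Y_partial, op=MPI.SUM)
```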
     """
     def __init__ (