@@ -19,19 +19,95 @@ class MPIMatrixMult(MPILinearOperator):
1919 Parameters
2020 ----------
2121 A : :obj:`numpy.ndarray`
22- Matrix multiplication operator of size
23- :math:`[ \times ]`
22+ Local block of the matrix multiplication operator of shape ``(M_loc, K)``
23+ where ``M_loc`` is the number of rows stored on this MPI rank and
24+ ``K`` is the global number of columns.
25+ N : :obj:`int`
26+ Global number of columns of the operand matrix ``X`` (i.e., its second dimension).
2427 saveAt : :obj:`bool`, optional
2528 Save ``A`` and ``A.H`` to speed up the computation of adjoint
2629 (``True``) or create ``A.H`` on-the-fly (``False``).
27- Note that ``saveAt=True`` will double the amount of required memory
30+ Note that ``saveAt=True`` will double the amount of required memory.
31+ The default is ``False``.
2832 base_comm : :obj:`mpi4py.MPI.Comm`, optional
2933 MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
3034 dtype : :obj:`str`, optional
3135 Type of elements in input array.
3236
37+ Attributes
38+ ----------
39+ shape : :obj:`tuple`
40+ Operator shape
41+
42+ Raises
43+ ------
44+ Exception
45+ If the operator is created without a perfect-square number of MPI ranks.
46+ ValueError
47+ If the input vector does not have the correct partition type.
48+
3349 Notes
3450 -----
51+ This implementation uses a 1D block distribution of the operand matrix,
52+ while the operator is replicated by a factor of :math:`\sqrt{P}` across the
53+ :math:`P` processes, which are arranged in a square :math:`\sqrt{P}\times\sqrt{P}` process grid.
54+
55+ The operator implements a distributed matrix-matrix multiplication where:
56+
57+ - The matrix ``A`` is distributed across MPI processes in a block-row fashion
58+ - Each process holds a local block of ``A`` with shape ``(M_loc, K)``
59+ - The operand matrix ``X`` is distributed in a block-column fashion
60+ - Communication is minimized by using a 2D process grid layout (a schematic of this layout is sketched below)
61+
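
As a rough illustration of this layout, the following sketch shows how each
rank can derive its group and layer indices and the corresponding
sub-communicators. It assumes a row-major mapping of ranks onto the
:math:`\sqrt{P}\times\sqrt{P}` grid; the operator's internal convention may
differ::

    # a minimal sketch, assuming a row-major rank-to-grid mapping
    import math
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()

    p_prime = math.isqrt(size)          # sqrt(P); requires a square rank count
    assert p_prime * p_prime == size

    group_id = rank % p_prime           # column index in the process grid
    layer_id = rank // p_prime          # row index in the process grid

    # sub-communicators analogous to the ones created by the operator
    layer_comm = comm.Split(color=layer_id, key=group_id)
    group_comm = comm.Split(color=group_id, key=layer_id)
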
62+ The forward operation computes :math:`Y = A \cdot X` where:
63+
64+ - :math:`A` is the distributed matrix operator of shape ``(M, K)``
65+ - :math:`X` is the distributed operand matrix of shape ``(K, N)``
66+ - :math:`Y` is the resulting distributed matrix of shape ``(M, N)``
67+
68+ The adjoint operation computes :math:`Y = A^H \cdot X` where :math:`A^H`
69+ is the conjugate transpose of :math:`A`.
70+
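Ignoring the distribution for a moment, the operator therefore acts like a
plain matrix product. A single-process NumPy sketch of the forward and
adjoint actions (with illustrative sizes and array names) is::

    import numpy as np

    M, K, N = 6, 4, 5
    A = np.random.randn(M, K) + 1j * np.random.randn(M, K)  # operator, (M, K)
    X = np.random.randn(K, N) + 1j * np.random.randn(K, N)  # operand,  (K, N)

    Y = A @ X            # forward:  (M, K) @ (K, N) -> (M, N)
    Z = A.conj().T @ Y   # adjoint:  (K, M) @ (M, N) -> (K, N)
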
71+ Steps for the Forward Operation (:math:`Y = A \cdot X`)
72+ --------------------------------------------------------
73+ 1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
74+ of shape ``(K, N)``) is reshaped to ``(K, N_local)`` where ``N_local``
75+ is the number of columns assigned to the current process.
76+
77+ 2. **Data Broadcasting**: Within each layer (processes with the same ``layer_id``),
78+ the operand data is broadcast from the process whose ``group_id`` matches
79+ the ``layer_id``. This ensures all processes in a layer have access to
80+ the same operand columns.
81+
82+ 3. **Local Computation**: Each process computes ``A_local @ X_local``, where
83+ ``A_local`` is the local block of matrix ``A`` (shape ``(M_local, K)``) and
84+ ``X_local`` is the broadcast operand (shape ``(K, N_local)``).
85+
86+ 4. **Layer Gather**: Results from all processes in each layer are gathered
87+ using ``allgather`` to reconstruct the full result matrix vertically (see the sketch below).
88+
89+
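A self-contained mpi4py sketch of these four steps follows. All names
(``A_local``, ``X_local``, the toy sizes, and the row-major rank-to-grid
mapping) are illustrative assumptions, not the operator's exact internals::

    import math
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()
    p_prime = math.isqrt(size)                    # run with a square rank count
    group_id, layer_id = rank % p_prime, rank // p_prime
    layer_comm = comm.Split(color=layer_id, key=group_id)

    M_local, K, N_local = 3, 4, 2                 # toy local sizes
    A_local = np.full((M_local, K), float(rank))  # local row block of A

    # step 1: local columns of the operand, shape (K, N_local)
    X_local = np.arange(K * N_local, dtype=float).reshape(K, N_local)

    # step 2: broadcast within the layer from the process whose
    # group_id equals the layer_id
    X_local = layer_comm.bcast(X_local, root=layer_id)

    # step 3: local block product, shape (M_local, N_local)
    Y_local = A_local @ X_local

    # step 4: gather the row blocks of the layer and stack them
    # vertically to obtain the full (M, N_local) result
    Y = np.vstack(layer_comm.allgather(Y_local))
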
90+ Steps for the Adjoint Operation (:math:`Y = A^H \cdot X`)
91+ ----------------------------------------------------------
92+ The adjoint operation performs the conjugate transpose multiplication:
93+
94+ 1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(M, N_local)``
95+ representing the local columns of the input matrix.
96+
97+ 2. **Local Adjoint Computation**: Each process computes
98+ ``A_local.H @ X_tile``, where ``A_local.H`` is either the pre-computed
99+ ``At`` (if ``saveAt=True``) or formed on-the-fly as ``A.T.conj()``
100+ (if ``saveAt=False``).
101+ Each process thus multiplies its conjugate-transposed local block
102+ ``A_local^H`` (shape ``(K, M_block)``) with the extracted ``X_tile``
103+ (shape ``(M_block, N_local)``), producing a partial result of shape
104+ ``(K, N_local)``, i.e. the local contribution of the corresponding
105+ columns of :math:`A^H` to the final result.
106+
107+ 3. **Layer Reduction**: Since the full result :math:`Y = A^H \cdot X` is the
108+ sum of contributions from all column blocks of :math:`A^H`, processes in the
109+ same layer perform an ``allreduce`` sum to combine their partial results.
110+ This gives the complete ``(K, N_local)`` result for their assigned columns (see the sketch below).
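
A corresponding self-contained sketch of the adjoint path (same illustrative
assumptions as above; equal-sized row blocks are assumed so that each row
offset can be computed from the rank within the layer) is::

    import math
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()
    p_prime = math.isqrt(size)
    group_id, layer_id = rank % p_prime, rank // p_prime
    layer_comm = comm.Split(color=layer_id, key=group_id)

    M_local, K, N_local = 3, 4, 2
    M = p_prime * M_local                            # global rows (equal blocks)
    A_local = np.full((M_local, K), float(rank))
    saveAt = True
    At_local = A_local.T.conj() if saveAt else None  # pre-computed A^H

    # step 1: local columns of the input, shape (M, N_local)
    X_local = np.arange(M * N_local, dtype=float).reshape(M, N_local)

    # step 2: extract the rows matching this process' block of A and
    # apply the conjugate-transposed local operator -> (K, N_local)
    row_start = layer_comm.Get_rank() * M_local
    X_tile = X_local[row_start:row_start + M_local, :]
    Y_partial = (At_local if saveAt else A_local.T.conj()) @ X_tile

    # step 3: sum the partial contributions of all processes in the layer
    Y = layer_comm.allreduce(Y_partial, op=MPI.SUM)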
35111 """
36112 def __init__(
37113 self,
@@ -58,14 +134,14 @@ def __init__(
58134 self.base_comm = base_comm
59135 self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
60136 self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
137+
61138 self.A = A.astype(np.dtype(dtype))
62139 if saveAt: self.At = A.T.conj()
63140
64141 self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
65142 self.K = A.shape[1]
66143 self.N = N
67144
68- # Determine how many columns each group holds
69145 block_cols = int(math.ceil(self.N / self._P_prime))
70146 blk_rows = int(math.ceil(self.M / self._P_prime))
71147