@@ -19,19 +19,96 @@ class MPIMatrixMult(MPILinearOperator):
     Parameters
     ----------
     A : :obj:`numpy.ndarray`
-        Matrix multiplication operator of size
-        :math:`[ \times ]`
+        Local block of the matrix multiplication operator of shape ``(M_loc, K)``
+        where ``M_loc`` is the number of rows stored on this MPI rank and
+        ``K`` is the global number of columns.
+    N : :obj:`int`
+        Global leading dimension of the operand matrix (number of columns).
     saveAt : :obj:`bool`, optional
         Save ``A`` and ``A.H`` to speed up the computation of adjoint
         (``True``) or create ``A.H`` on-the-fly (``False``)
-        Note that ``saveAt=True`` will double the amount of required memory
+        Note that ``saveAt=True`` will double the amount of required memory.
+        The default is ``False``.
     base_comm : :obj:`mpi4py.MPI.Comm`, optional
         MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
     dtype : :obj:`str`, optional
         Type of elements in input array.
 
+    Attributes
+    ----------
+    shape : :obj:`tuple`
+        Operator shape
+
+    Raises
+    ------
+    Exception
+        If the operator is created without a square number of MPI ranks.
+    ValueError
+        If the input vector does not have the correct partition type.
+
     Notes
     -----
+    This implementation uses a 1D block distribution of the operand matrix,
+    while the operator is replicated across the :math:`P` processes by a
+    factor of :math:`\sqrt{P}`, arranged on a square process grid
+    (:math:`\sqrt{P}\times\sqrt{P}`).
+
+    The operator implements a distributed matrix-matrix multiplication where:
+
+    - The matrix ``A`` is distributed across MPI processes in a block-row fashion
+    - Each process holds a local block of ``A`` with shape ``(M_loc, K)``
+    - The operand matrix ``X`` is distributed in a block-column fashion
+    - Communication is minimized by using a 2D process grid layout
+      (see the layout sketch below)
+
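+    As an illustration of this layout only, the grid coordinates and block
+    sizes of each rank could be derived as in the sketch below; the row-major
+    rank-to-grid mapping and the example sizes are assumptions made for the
+    sketch, not necessarily what the operator uses internally.
+
+    .. code-block:: python
+
+        import math
+
+        from mpi4py import MPI
+
+        comm = MPI.COMM_WORLD
+        rank, size = comm.Get_rank(), comm.Get_size()
+
+        p_prime = math.isqrt(size)           # square grid, size == p_prime ** 2
+        assert p_prime * p_prime == size, "a square number of ranks is required"
+
+        # assumed row-major placement of ranks on the sqrt(P) x sqrt(P) grid
+        layer_id, group_id = divmod(rank, p_prime)
+
+        M, K, N = 1000, 200, 300             # hypothetical global sizes
+        blk_rows = math.ceil(M / p_prime)    # rows of A per block-row of ranks
+        block_cols = math.ceil(N / p_prime)  # columns of X per block-column
+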
+    The forward operation computes :math:`Y = A \cdot X` where:
+
+    - :math:`A` is the distributed matrix operator of shape ``(M, K)``
+    - :math:`X` is the distributed operand matrix of shape ``(K, N)``
+    - :math:`Y` is the resulting distributed matrix of shape ``(M, N)``
+
+    The adjoint operation computes :math:`Y = A^H \cdot X` where :math:`A^H`
+    is the conjugate transpose of :math:`A`.
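+
+    Ignoring the distribution for a moment, the two operations are equivalent
+    to the following single-process NumPy reference (a sketch with made-up
+    sizes, shown only to fix the shapes and the meaning of the adjoint):
+
+    .. code-block:: python
+
+        import numpy as np
+
+        M, K, N = 6, 4, 5
+        A = np.random.randn(M, K) + 1j * np.random.randn(M, K)
+        X = np.random.randn(K, N) + 1j * np.random.randn(K, N)
+
+        Y = A @ X             # forward: (M, K) @ (K, N) -> (M, N)
+        Z = A.conj().T @ Y    # adjoint: (K, M) @ (M, N) -> (K, N)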
+
+    Steps for the Forward Operation (:math:`Y = A \cdot X`)
+    --------------------------------------------------------
+    1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
+       of shape ``(K, N)``) is reshaped to ``(K, N_local)`` where ``N_local``
+       is the number of columns assigned to the current process.
+
+    2. **Data Broadcasting**: Within each layer (processes with the same
+       ``layer_id``), the operand data is broadcast from the process whose
+       ``group_id`` matches the ``layer_id``. This ensures all processes in a
+       layer have access to the same operand columns.
+
+    3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
+
+       - ``A_local`` is the local block of matrix ``A`` (shape ``M_local × K``)
+       - ``X_local`` is the broadcast operand (shape ``K × N_local``)
+
+    4. **Result Gathering**: Results from all processes in each layer are
+       gathered using ``allgather`` to reconstruct the full result matrix
+       vertically (see the per-rank sketch below).
+
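+    Put together, the four steps above amount roughly to the per-rank sketch
+    below; the function and argument names are hypothetical and the choice of
+    broadcast root is an assumption, so this is not the operator's actual
+    implementation, only an illustration of the communication pattern.
+
+    .. code-block:: python
+
+        import numpy as np
+
+        def forward_block_sketch(A_local, x, layer_comm, layer_id, K, N_local):
+            # 1. input preparation: local (K, N_local) slab of the operand
+            X_local = x.reshape(K, N_local)
+            # 2. data broadcasting within the layer; the root is assumed to be
+            #    the rank whose group_id equals this layer's layer_id
+            X_local = layer_comm.bcast(X_local, root=layer_id)
+            # 3. local computation: (M_local, K) @ (K, N_local)
+            Y_local = A_local @ X_local
+            # 4. result gathering: stack the row blocks held by the layer
+            return np.vstack(layer_comm.allgather(Y_local))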
+
+    Steps for the Adjoint Operation (:math:`Y = A^H \cdot X`)
+    ----------------------------------------------------------
+    The adjoint operation performs the conjugate-transpose multiplication:
+
+    1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(M, N_local)``
+       representing the local columns of the input matrix.
+
+    2. **Local Adjoint Computation**: Each process computes
+       ``A_local.H @ X_tile``, where ``A_local.H`` is either:
+
+       - the pre-computed ``At`` (if ``saveAt=True``)
+       - computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``)
+
+       Each process multiplies its transposed local ``A`` block ``A_local^H``
+       (shape ``K × M_block``) with the extracted ``X_tile``
+       (shape ``M_block × N_local``), producing a partial result of shape
+       ``(K, N_local)``, i.e. the local contribution of this block of columns
+       of ``A^H`` to the final result.
+
+    3. **Layer Reduction**: Since the full result :math:`Y = A^H \cdot X` is the
+       sum of contributions from all column blocks of ``A^H``, processes in the
+       same layer perform an ``allreduce`` sum to combine their partial results,
+       giving the complete ``(K, N_local)`` result for their assigned columns
+       (see the per-rank sketch below).
+
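+    Correspondingly, a rough per-rank sketch of the adjoint pass follows; the
+    function and argument names, and the way the row tile is selected, are
+    assumptions for illustration only.
+
+    .. code-block:: python
+
+        import numpy as np
+        from mpi4py import MPI
+
+        def adjoint_block_sketch(A_local, x, layer_comm, M, N_local,
+                                 row_start, row_end, At=None):
+            # 1. input reshaping: local (M, N_local) slab of the input matrix
+            X_local = x.reshape(M, N_local)
+            # rows matching this rank's block of A (index range assumed known)
+            X_tile = X_local[row_start:row_end, :]
+            # 2. local adjoint computation with the saved or on-the-fly A^H
+            AH_local = At if At is not None else A_local.T.conj()
+            Y_partial = AH_local @ X_tile   # (K, M_block) @ (M_block, N_local)
+            # 3. layer reduction: sum the partial contributions over the layer
+            return layer_comm.allreduce(Y_partial, op=MPI.SUM)
+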
     """
     def __init__(
             self,
@@ -58,14 +135,14 @@ def __init__(
         self.base_comm = base_comm
         self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
         self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
+
         self.A = A.astype(np.dtype(dtype))
         if saveAt: self.At = A.T.conj()
 
         self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
         self.K = A.shape[1]
         self.N = N
 
-        # Determine how many columns each group holds
         block_cols = int(math.ceil(self.N / self._P_prime))
         blk_rows = int(math.ceil(self.M / self._P_prime))
 