@@ -82,16 +82,16 @@ class MPIMatrixMult(MPILinearOperator):
       of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
       is the number of columns assigned to the current process.

-    2. **Data Broadcasting**: Within each layer (processes with same ``layer_id``),
-       the operand data is broadcast from the process whose ``group_id`` matches
-       the ``layer_id``. This ensures all processes in a layer have access to
-       the same operand columns.
+    2. **Data Broadcasting**: Within each row (processes with the same ``row_id``),
+       the operand data is broadcast from the process whose ``col_id`` matches
+       the ``row_id`` (processes along the diagonal). This ensures all processes
+       in a row have access to the same operand columns.

     3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
        - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
        - ``X_local`` is the broadcasted operand (shape ``K x M_local``)

-    4. **Layer Gather**: Results from all processes in each layer are gathered
+    4. **Row-wise Gather**: Results from all processes in each row are gathered
        using ``allgather`` to reconstruct the full result matrix vertically.

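As a standalone illustration of steps 3-4 (a minimal mpi4py sketch, not the operator itself): each rank holds the row-block of A selected by its col_id and the operand columns selected by its row_id, computes the local product, and the row communicator gathers the blocks. The 2 x 2 grid, the matrix sizes and the random data are arbitrary choices; run with 4 ranks.

```python
# Minimal sketch of the forward pattern (steps 3-4); not the class in this diff.
import math

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
p = math.isqrt(size)                            # square process grid, p x p
assert p * p == size
row_id, col_id = rank // p, rank % p
row_comm = comm.Split(color=row_id, key=col_id)

N, K, M = 8, 6, 4                               # arbitrary global shapes: A is (N, K), X is (K, M)
rng = np.random.default_rng(0)
A_full = rng.standard_normal((N, K))            # replicated here only to carve out the local blocks
X_full = rng.standard_normal((K, M))

blk_rows, blk_cols = math.ceil(N / p), math.ceil(M / p)
A_local = A_full[col_id * blk_rows:(col_id + 1) * blk_rows]            # row-block of A, indexed by col_id
X_local = X_full[:, row_id * blk_cols:(row_id + 1) * blk_cols]         # column-block of X, indexed by row_id

# Local product, then gather all row-blocks of the result within the row
Y_local = np.vstack(row_comm.allgather(A_local @ X_local))             # shape (N, M_local)
assert np.allclose(
    Y_local, (A_full @ X_full)[:, row_id * blk_cols:(row_id + 1) * blk_cols]
)
```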
     **Adjoint Operation step-by-step**
@@ -112,9 +112,9 @@ class MPIMatrixMult(MPILinearOperator):
       producing a partial result of shape ``(K, M_local)``.
       This computes the local contribution of columns of ``A^H`` to the final result.

-    3. **Layer Reduction**: Since the full result ``Y = A^H \cdot X`` is the
+    3. **Row-wise Reduction**: Since the full result ``Y = A^H \cdot X`` is the
       sum of contributions from all column blocks of ``A^H``, processes in the
-       same layer perform an ``allreduce`` sum to combine their partial results.
+       same row perform an ``allreduce`` sum to combine their partial results.
       This gives the complete ``(K, M_local)`` result for their assigned columns.

     """
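The row-wise reduction above rests on the identity that A^H X equals the sum of the per-row-block contributions A_k^H X_k. A NumPy-only sanity check of that identity (arbitrary shapes and data, no MPI required):

```python
# Check: A^H @ X == sum over row-blocks k of A_k^H @ X_k, which is what allreduce(SUM) rebuilds.
import numpy as np

rng = np.random.default_rng(0)
N, K, M_local, p = 8, 6, 2, 2                   # arbitrary sizes, p row-blocks
A = rng.standard_normal((N, K)) + 1j * rng.standard_normal((N, K))
X = rng.standard_normal((N, M_local)) + 1j * rng.standard_normal((N, M_local))

blk = N // p
partials = [
    A[k * blk:(k + 1) * blk].conj().T @ X[k * blk:(k + 1) * blk]       # each partial is (K, M_local)
    for k in range(p)
]
assert np.allclose(sum(partials), A.conj().T @ X)
```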
@@ -135,29 +135,27 @@ def __init__(
         if self._P_prime * self._C != size:
             raise Exception(f"Number of processes must be a square number, provided {size} instead...")

-        # Compute this process's group and layer indices
-        self._group_id = rank % self._P_prime
-        self._layer_id = rank // self._P_prime
+        self._col_id = rank % self._P_prime
+        self._row_id = rank // self._P_prime

-        # Split communicators by layer (rows) and by group (columns)
         self.base_comm = base_comm
-        self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
-        self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
+        self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
+        self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)

         self.A = A.astype(np.dtype(dtype))
         if saveAt: self.At = A.T.conj()

-        self.N = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
+        self.N = self._row_comm.allreduce(self.A.shape[0], op=MPI.SUM)
         self.K = A.shape[1]
         self.M = M

         block_cols = int(math.ceil(self.M / self._P_prime))
         blk_rows = int(math.ceil(self.N / self._P_prime))

-        self._row_start = self._group_id * blk_rows
+        self._row_start = self._col_id * blk_rows
         self._row_end = min(self.N, self._row_start + blk_rows)

-        self._col_start = self._layer_id * block_cols
+        self._col_start = self._row_id * block_cols
         self._col_end = min(self.M, self._col_start + block_cols)

         self._local_ncols = self._col_end - self._col_start
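Purely for illustration, the ceil-based block assignment computed above can be traced without MPI for a hypothetical 2 x 2 grid with N = M = 5, where the split leaves the last block short:

```python
# Trace of the block assignment for a hypothetical 2 x 2 grid (no MPI needed).
import math

P_prime, N, M = 2, 5, 5
blk_rows, block_cols = math.ceil(N / P_prime), math.ceil(M / P_prime)   # 3 and 3
for rank in range(P_prime * P_prime):
    row_id, col_id = rank // P_prime, rank % P_prime
    rows = (col_id * blk_rows, min(N, (col_id + 1) * blk_rows))         # row-block of A
    cols = (row_id * block_cols, min(M, (row_id + 1) * block_cols))     # column-block of the operand
    print(f"rank {rank}: row_id={row_id} col_id={col_id} A rows {rows} operand cols {cols}")
# rank 0: row_id=0 col_id=0 A rows (0, 3) operand cols (0, 3)
# rank 1: row_id=0 col_id=1 A rows (3, 5) operand cols (0, 3)
# rank 2: row_id=1 col_id=0 A rows (0, 3) operand cols (3, 5)
# rank 3: row_id=1 col_id=1 A rows (3, 5) operand cols (3, 5)
```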
@@ -184,7 +182,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
         X_local = x_arr.astype(self.dtype)
         Y_local = ncp.vstack(
-            self._layer_comm.allgather(
+            self._row_comm.allgather(
                 ncp.matmul(self.A, X_local)
             )
         )
@@ -208,6 +206,6 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         X_tile = x_arr[self._row_start:self._row_end, :]
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
         Y_local = ncp.matmul(A_local, X_tile)
-        y_layer = self._layer_comm.allreduce(Y_local, op=MPI.SUM)
+        y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM)
         y[:] = y_layer.flatten()
         return y
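A matching standalone mpi4py sketch of this adjoint path (again an illustration, not the operator): each rank applies the conjugate transpose of its row-block of A to the corresponding row-slice of the operand, and the row communicator sums the partial results. Run with 4 ranks; sizes are arbitrary.

```python
# Standalone sketch of the adjoint pattern on a 2 x 2 grid; run with 4 ranks.
import math

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
p = math.isqrt(size)
assert p * p == size
row_id, col_id = rank // p, rank % p
row_comm = comm.Split(color=row_id, key=col_id)

N, K, M = 8, 6, 4                               # arbitrary shapes: A is (N, K), operand X is (N, M)
rng = np.random.default_rng(1)
A_full = rng.standard_normal((N, K))
X_full = rng.standard_normal((N, M))

blk_rows, blk_cols = math.ceil(N / p), math.ceil(M / p)
A_local = A_full[col_id * blk_rows:(col_id + 1) * blk_rows]                    # this rank's rows of A
X_tile = X_full[col_id * blk_rows:(col_id + 1) * blk_rows,
                row_id * blk_cols:(row_id + 1) * blk_cols]                     # matching rows, local columns

Y_partial = A_local.conj().T @ X_tile                                          # (K, M_local) contribution
Y = row_comm.allreduce(Y_partial, op=MPI.SUM)                                  # complete (K, M_local) block
assert np.allclose(
    Y, (A_full.conj().T @ X_full)[:, row_id * blk_cols:(row_id + 1) * blk_cols]
)
```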