@@ -82,16 +82,16 @@ class MPIMatrixMult(MPILinearOperator):
     of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
     is the number of columns assigned to the current process.
 
-    2. **Data Broadcasting**: Within each layer (processes with same ``layer_id``),
-       the operand data is broadcast from the process whose ``group_id`` matches
-       the ``layer_id``. This ensures all processes in a layer have access to
-       the same operand columns.
+    2. **Data Broadcasting**: Within each row (processes with same ``row_id``),
+       the operand data is broadcast from the process whose ``col_id`` matches
+       the ``row_id`` (processes along the diagonal). This ensures all processes
+       in a row have access to the same operand columns.
 
     3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
        - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
        - ``X_local`` is the broadcasted operand (shape ``K x M_local``)
 
-    4. **Layer Gather**: Results from all processes in each layer are gathered
+    4. **Row-wise Gather**: Results from all processes in each row are gathered
        using ``allgather`` to reconstruct the full result matrix vertically.
 
     **Adjoint Operation step-by-step**
@@ -112,9 +112,9 @@ class MPIMatrixMult(MPILinearOperator):
        producing a partial result of shape ``(K, M_local)``.
        This computes the local contribution of columns of ``A^H`` to the final result.
 
-    3. **Layer Reduction**: Since the full result ``Y = A^H \cdot X`` is the
+    3. **Row-wise Reduction**: Since the full result ``Y = A^H \cdot X`` is the
        sum of contributions from all column blocks of ``A^H``, processes in the
-       same layer perform an ``allreduce`` sum to combine their partial results.
+       same row perform an ``allreduce`` sum to combine their partial results.
        This gives the complete ``(K, M_local)`` result for their assigned columns.
 
     """
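
Note (not part of the diff): the communication pattern that the renamed docstring describes can be reproduced with a minimal mpi4py + NumPy sketch. Everything below is illustrative only; the sizes, variable names, and even block split are assumptions rather than the operator's actual code, and it expects a square number of ranks (e.g. `mpirun -n 4 python sketch.py`).

```python
# Illustrative sketch of the row-wise broadcast / allgather / allreduce
# pattern described in the docstring above; not the operator's own code.
import math
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
p_prime = math.isqrt(size)                       # side of the square process grid
assert p_prime * p_prime == size, "run with a square number of ranks"

row_id, col_id = rank // p_prime, rank % p_prime
row_comm = comm.Split(color=row_id, key=col_id)  # processes sharing row_id

# Toy global sizes (assumed): A is (N, K), X is (K, M); blocks divide evenly.
N, K, M = 8, 6, 4
n_loc, m_loc = N // p_prime, M // p_prime
rng = np.random.default_rng(rank)
A_local = rng.standard_normal((n_loc, K))        # row block of A, indexed by col_id
X_local = rng.standard_normal((K, m_loc))        # column block of X, owned by the row

# Step 2 (forward): broadcast the operand within the row from the diagonal rank.
X_local = row_comm.bcast(X_local, root=row_id)

# Steps 3-4 (forward): local product, then gather row blocks vertically.
Y_local = np.vstack(row_comm.allgather(A_local @ X_local))       # (N, m_loc)

# Adjoint step 3: each rank applies A_local^H to its tile, and the row sums
# the partial contributions into the full (K, m_loc) block.
tile = Y_local[col_id * n_loc:(col_id + 1) * n_loc, :]
Z_local = row_comm.allreduce(A_local.conj().T @ tile, op=MPI.SUM)
```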
@@ -135,29 +135,28 @@ def __init__(
         if self._P_prime * self._C != size:
             raise Exception(f"Number of processes must be a square number, provided {size} instead...")
 
-        # Compute this process's group and layer indices
-        self._group_id = rank % self._P_prime
-        self._layer_id = rank // self._P_prime
+        self._col_id = rank % self._P_prime
+        self._row_id = rank // self._P_prime
 
         # Split communicators by layer (rows) and by group (columns)
         self.base_comm = base_comm
-        self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
-        self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
+        self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
+        self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
 
         self.A = A.astype(np.dtype(dtype))
         if saveAt: self.At = A.T.conj()
 
-        self.N = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
+        self.N = self._row_comm.allreduce(self.A.shape[0], op=MPI.SUM)
         self.K = A.shape[1]
         self.M = M
 
         block_cols = int(math.ceil(self.M / self._P_prime))
         blk_rows = int(math.ceil(self.N / self._P_prime))
 
-        self._row_start = self._group_id * blk_rows
+        self._row_start = self._col_id * blk_rows
         self._row_end = min(self.N, self._row_start + blk_rows)
 
-        self._col_start = self._layer_id * block_cols
+        self._col_start = self._row_id * block_cols
         self._col_end = min(self.M, self._col_start + block_cols)
 
         self._local_ncols = self._col_end - self._col_start
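
As a quick sanity check of the renamed indices (illustrative, not part of the diff), a hypothetical 2 x 2 grid (`size = 4`, so `_P_prime = 2`) maps ranks to grid positions as follows:

```python
# Rank -> (row_id, col_id) mapping for a hypothetical 2 x 2 grid,
# following the same formulas used in __init__ above.
P_prime = 2
for rank in range(P_prime * P_prime):
    col_id = rank % P_prime    # position along the row
    row_id = rank // P_prime   # which row of the grid
    print(f"rank {rank}: row_id={row_id}, col_id={col_id}")
# rank 0: row_id=0, col_id=0    rank 1: row_id=0, col_id=1
# rank 2: row_id=1, col_id=0    rank 3: row_id=1, col_id=1
```

With this mapping, `Split(color=row_id, ...)` groups ranks {0, 1} and {2, 3} into row communicators, while `Split(color=col_id, ...)` groups {0, 2} and {1, 3} into column communicators.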
@@ -184,7 +183,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
         X_local = x_arr.astype(self.dtype)
         Y_local = ncp.vstack(
-            self._layer_comm.allgather(
+            self._row_comm.allgather(
                 ncp.matmul(self.A, X_local)
             )
         )
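
For intuition only (not part of the diff): the `vstack(allgather(...))` in `_matvec` reconstructs the full product because stacking the per-row-block products of `A` equals multiplying with the full `A`. A serial NumPy check with assumed toy sizes:

```python
# Serial check: stacking row-block products reproduces the full product,
# which is what the row-wise allgather achieves in parallel.
import numpy as np

rng = np.random.default_rng(0)
N, K, m_loc, p_prime = 8, 6, 2, 2     # assumed toy sizes, N divisible by p_prime
A = rng.standard_normal((N, K))
X_local = rng.standard_normal((K, m_loc))

n_loc = N // p_prime
blocks = [A[i * n_loc:(i + 1) * n_loc] @ X_local for i in range(p_prime)]
assert np.allclose(np.vstack(blocks), A @ X_local)
```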
@@ -208,6 +207,6 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         X_tile = x_arr[self._row_start:self._row_end, :]
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
         Y_local = ncp.matmul(A_local, X_tile)
-        y_layer = self._layer_comm.allreduce(Y_local, op=MPI.SUM)
+        y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM)
         y[:] = y_layer.flatten()
         return y
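
Similarly for the adjoint (illustrative only, not part of the diff): summing the per-block contributions `A_block^H @ X_block` over a row is exactly `A^H @ X`, which is what the row-wise `allreduce` in `_rmatvec` computes. A serial NumPy check with assumed toy sizes:

```python
# Serial check: the sum of per-block adjoint contributions equals the full
# adjoint product, mirroring the row-wise allreduce in _rmatvec.
import numpy as np

rng = np.random.default_rng(0)
N, K, m_loc, p_prime = 8, 6, 2, 2     # assumed toy sizes, N divisible by p_prime
A = rng.standard_normal((N, K))
X = rng.standard_normal((N, m_loc))

n_loc = N // p_prime
partials = [A[i * n_loc:(i + 1) * n_loc].conj().T @ X[i * n_loc:(i + 1) * n_loc]
            for i in range(p_prime)]
assert np.allclose(sum(partials), A.conj().T @ X)
```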