@@ -25,7 +25,7 @@ class MPIMatrixMult(MPILinearOperator):
         Local block of the matrix of shape :math:`[M_{loc} \times K]`
         where ``M_loc`` is the number of rows stored on this MPI rank and
         ``K`` is the global number of columns.
-    N : :obj:`int`
+    M : :obj:`int`
         Global leading dimension (i.e., number of columns) of the matrices
         representing the input model and data vectors.
     saveAt : :obj:`bool`, optional
@@ -55,9 +55,9 @@ class MPIMatrixMult(MPILinearOperator):
    This operator performs a matrix-matrix multiplication, whose forward
    operation can be described as :math:`Y = A \cdot X` where:
 
-    - :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[M \times K]`
-    - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times N]`
-    - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[M \times N]`
+    - :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]`
+    - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
+    - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
 
    whilst the adjoint operation is represented by
    :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where
@@ -70,16 +70,16 @@ class MPIMatrixMult(MPILinearOperator):
 
    - The matrix ``A`` is distributed across MPI processes in a block-row fashion
      and each process holds a local block of ``A`` with shape
-      :math:`[M_{loc} \times K]`
+      :math:`[N_{loc} \times K]`
    - The operand matrix ``X`` is distributed in a block-column fashion and
-      and each process holds a local block of ``X`` with shape
-      :math:`[K \times N_{loc}]`
+      each process holds a local block of ``X`` with shape
+      :math:`[K \times M_{loc}]`
    - Communication is minimized by using a 2D process grid layout
 
    **Forward Operation step-by-step**
 
    1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
-       of shape ``(K, N)``) is reshaped to ``(K, N_local)`` where ``N_local``
+       of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
       is the number of columns assigned to the current process.
 
    2. **Data Broadcasting**: Within each layer (processes with same ``layer_id``),
@@ -88,8 +88,8 @@ class MPIMatrixMult(MPILinearOperator):
       the same operand columns.
 
    3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
-       - ``A_local`` is the local block of matrix ``A`` (shape ``M_local x K``)
-       - ``X_local`` is the broadcasted operand (shape ``K x N_local``)
+       - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
+       - ``X_local`` is the broadcasted operand (shape ``K x M_local``)
 
    4. **Layer Gather**: Results from all processes in each layer are gathered
       using ``allgather`` to reconstruct the full result matrix vertically.
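The local computation and layer gather in steps 3-4 can be mimicked serially; the sketch below stacks per-row-block products with ``np.vstack`` in place of the MPI ``allgather`` (toy sizes and a hypothetical grid side ``P_prime``, purely illustrative):

```python
import math
import numpy as np

# toy global sizes and grid side (assumed for illustration only)
N, K, M = 6, 4, 5
P_prime = 2

A = np.random.rand(N, K)                  # full operator, kept on one process here
X = np.random.rand(K, M)                  # full operand

blk_rows = int(math.ceil(N / P_prime))    # rows of A per process (block-row split)
block_cols = int(math.ceil(M / P_prime))  # columns of X per layer (block-column split)

# one layer works on one block of columns of X (the broadcast of step 2)
layer_id = 0
X_local = X[:, layer_id * block_cols:min(M, (layer_id + 1) * block_cols)]

# step 3: each process in the layer multiplies its row block of A with X_local
partials = []
for group_id in range(P_prime):
    rs = group_id * blk_rows
    re = min(N, rs + blk_rows)
    partials.append(A[rs:re] @ X_local)   # shape (N_local, M_local)

# step 4: the layer "allgather" stacks the partial results vertically
Y_layer = np.vstack(partials)             # shape (N, M_local)
assert np.allclose(Y_layer, A @ X_local)
```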
@@ -98,7 +98,7 @@ class MPIMatrixMult(MPILinearOperator):
 
    The adjoint operation performs the conjugate transpose multiplication:
 
-    1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(M, N_local)``
+    1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N, M_local)``
       representing the local columns of the input matrix.
 
    2. **Local Adjoint Computation**:
@@ -107,21 +107,21 @@ class MPIMatrixMult(MPILinearOperator):
       - Pre-computed ``At`` (if ``saveAt=True``)
       - Computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``)
       Each process multiplies its transposed local ``A`` block ``A_local^H``
-       (shape ``K x M_block``)
-       with the extracted ``X_tile`` (shape ``M_block x N_local``),
-       producing a partial result of shape ``(K, N_local)``.
+       (shape ``K x N_block``)
+       with the extracted ``X_tile`` (shape ``N_block x M_local``),
+       producing a partial result of shape ``(K, M_local)``.
       This computes the local contribution of columns of ``A^H`` to the final result.
 
    3. **Layer Reduction**: Since the full result ``Y = A^H \cdot X`` is the
       sum of contributions from all column blocks of ``A^H``, processes in the
       same layer perform an ``allreduce`` sum to combine their partial results.
-       This gives the complete ``(K, N_local)`` result for their assigned columns.
+       This gives the complete ``(K, M_local)`` result for their assigned columns.
 
    """
    def __init__(
        self,
        A: NDArray,
-        N: int,
+        M: int,
        saveAt: bool = False,
        base_comm: MPI.Comm = MPI.COMM_WORLD,
        dtype: DTypeLike = "float64",
@@ -147,25 +147,25 @@ def __init__(
        self.A = A.astype(np.dtype(dtype))
        if saveAt: self.At = A.T.conj()
 
-        self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
+        self.N = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
        self.K = A.shape[1]
-        self.N = N
+        self.M = M
 
-        block_cols = int(math.ceil(self.N / self._P_prime))
-        blk_rows = int(math.ceil(self.M / self._P_prime))
+        block_cols = int(math.ceil(self.M / self._P_prime))
+        blk_rows = int(math.ceil(self.N / self._P_prime))
 
        self._row_start = self._group_id * blk_rows
-        self._row_end = min(self.M, self._row_start + blk_rows)
+        self._row_end = min(self.N, self._row_start + blk_rows)
 
        self._col_start = self._layer_id * block_cols
-        self._col_end = min(self.N, self._col_start + block_cols)
+        self._col_end = min(self.M, self._col_start + block_cols)
 
        self._local_ncols = self._col_end - self._col_start
        self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
        total_ncols = np.sum(self._rank_col_lens)
 
        self.dims = (self.K, total_ncols)
-        self.dimsd = (self.M, total_ncols)
+        self.dimsd = (self.N, total_ncols)
        shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
 
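The same ceil-based block sizes computed above determine which slice of ``A`` and of the operand each process owns; a standalone sketch that enumerates the ranges for every ``group_id``/``layer_id`` pair (toy sizes, and the grid side ``P_prime`` is assumed here rather than taken from the communicator):

```python
import math

# toy global sizes (illustrative only)
N, K, M = 10, 4, 7
P_prime = 3  # assumed side of the P_prime x P_prime process grid

blk_rows = int(math.ceil(N / P_prime))    # rows of A owned by each group
block_cols = int(math.ceil(M / P_prime))  # operand columns owned by each layer

for group_id in range(P_prime):
    row_start = group_id * blk_rows
    row_end = min(N, row_start + blk_rows)
    for layer_id in range(P_prime):
        col_start = layer_id * block_cols
        col_end = min(M, col_start + block_cols)
        print(f"group {group_id}, layer {layer_id}: "
              f"A rows [{row_start}:{row_end}), operand cols [{col_start}:{col_end})")
```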
@@ -174,8 +174,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
 
-        y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
-                             local_shapes=[(self.M * c) for c in self._rank_col_lens],
+        y = DistributedArray(global_shape=(self.N * self.dimsd[1]),
+                             local_shapes=[(self.N * c) for c in self._rank_col_lens],
                             mask=x.mask,
                             partition=Partition.SCATTER,
                             dtype=self.dtype)
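The output ``DistributedArray`` built here is the flattened ``(N, total_ncols)`` result, split column-block by column-block; a toy calculation of the ``global_shape``/``local_shapes`` arithmetic (the ``rank_col_lens`` values are made up for illustration):

```python
import numpy as np

N = 6
rank_col_lens = [3, 2, 2]              # per-rank column counts (assumed values)
total_ncols = int(np.sum(rank_col_lens))

global_shape = N * total_ncols         # flattened (N, total_ncols) output
local_shapes = [N * c for c in rank_col_lens]

assert sum(local_shapes) == global_shape
print(global_shape, local_shapes)      # 42 [18, 12, 12]
```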
@@ -204,7 +204,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
            dtype=self.dtype,
        )
 
-        x_arr = x.local_array.reshape((self.M, self._local_ncols)).astype(self.dtype)
+        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
        X_tile = x_arr[self._row_start:self._row_end, :]
        A_local = self.At if hasattr(self, "At") else self.A.T.conj()
        Y_local = ncp.matmul(A_local, X_tile)
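The partial product above, ``A_local^H @ X_tile``, only becomes the full ``(K, M_local)`` adjoint after the layer-wise ``allreduce`` described in the Notes; a serial check of that identity, summing the per-row-block contributions (toy complex-valued data, no MPI):

```python
import math
import numpy as np

# toy sizes (illustrative only): N rows of A, K columns, M_local operand columns
N, K, M_local = 6, 4, 3
P_prime = 2

rng = np.random.default_rng(0)
A = rng.standard_normal((N, K)) + 1j * rng.standard_normal((N, K))
X = rng.standard_normal((N, M_local)) + 1j * rng.standard_normal((N, M_local))

blk_rows = int(math.ceil(N / P_prime))
partial_sum = np.zeros((K, M_local), dtype=complex)
for group_id in range(P_prime):                  # one contribution per process in the layer
    rs = group_id * blk_rows
    re = min(N, rs + blk_rows)
    partial_sum += A[rs:re].conj().T @ X[rs:re]  # A_local^H @ X_tile, shape (K, M_local)

# the layer "allreduce" sum reproduces the full adjoint product
assert np.allclose(partial_sum, A.conj().T @ X)
```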