@@ -195,7 +195,6 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu
195195     return C[:orr, :orc]
196196
197197
198-
199198class _MPIBlockMatrixMult(MPILinearOperator):
200199 r"""MPI Blocked Matrix multiplication
201200
@@ -214,7 +213,7 @@ class _MPIBlockMatrixMult(MPILinearOperator):
214213 Global leading dimension (i.e., number of columns) of the matrices
215214 representing the input model and data vectors.
216215 saveAt : :obj:`bool`, optional
217- Save ``A`` and ``A.H`` to speed up the computation of adjoint
216+ Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
218217 (``True``) or create ``A.H`` on-the-fly (``False``)
219218 Note that ``saveAt=True`` will double the amount of required memory.
220219 Default is ``False``.
@@ -253,22 +252,22 @@ class _MPIBlockMatrixMult(MPILinearOperator):
253252 processes by a factor equivalent to :math:`\sqrt{P}` across a square process
254253 grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically:
255254
256- - The matrix ``A`` is distributed across MPI processes in a block-row fashion
257- and each process holds a local block of ``A`` with shape
255+ - The matrix :math:`\mathbf{A}` is distributed across MPI processes in a block-row fashion
256+ and each process holds a local block of :math:`\mathbf{A}` with shape
258257 :math:`[N_{loc} \times K]`
259- - The operand matrix ``X`` is distributed in a block-column fashion and
260- each process holds a local block of ``X`` with shape
258+ - The operand matrix :math:`\mathbf{X}` is distributed in a block-column fashion and
259+ each process holds a local block of :math:`\mathbf{X}` with shape
261260 :math:`[K \times M_{loc}]`
262261 - Communication is minimized by using a 2D process grid layout
263262
264263 **Forward Operation step-by-step**
265264
266- 1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
265+ 1. **Input Preparation**: The input vector ``x`` (flattened from matrix :math:`\mathbf{X}`
267266 of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
268267 is the number of columns assigned to the current process.
269268
270269 2. **Local Computation**: Each process computes ``A_local @ X_local`` where:
271- - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
270+ - ``A_local`` is the local block of matrix :math:`\mathbf{A}` (shape ``N_local x K``)
272271 - ``X_local`` is the broadcasted operand (shape ``K x M_local``)
273272
274273 3. **Row-wise Gather**: Results from all processes in each row are gathered
@@ -283,10 +282,10 @@ class _MPIBlockMatrixMult(MPILinearOperator):
283282 representing the local columns of the input matrix.
284283
285284 2. **Local Adjoint Computation**: Each process computes
286- ``A_local.H @ X_tile`` where ``A_local.H`` is either i) Pre-computed
287- and stored in ``At`` (if ``saveAt=True``), ii) computed on-the-fly as
285+ ``A_local.H @ X_tile`` where ``A_local.H`` is either pre-computed
286+ and stored in ``At`` (if ``saveAt=True``), or computed on-the-fly as
288287 ``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its
289- transposed local ``A`` block ``A_local^H`` (shape ``K x N_block``)
288+ transposed local :math:`\mathbf{A}` block ``A_local^H`` (shape ``K x N_block``)
290289 with the extracted ``X_tile`` (shape ``N_block x M_local``),
291290 producing a partial result of shape ``(K, M_local)``.
292291 This computes the local contribution of columns of ``A^H`` to the final
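The block-variant steps above (local multiply, row-wise gather in the forward, local adjoint multiply followed by a sum-reduction in the adjoint) can be illustrated with a simplified 1D sketch. This is a hedged illustration, not the operator's internal code: the replicated `X`, the toy sizes, and the variable names are assumptions, and the real operator additionally splits `X` column-wise across a process grid.

```python
# Simplified 1D analogue of the block variant's communication pattern:
# forward = local multiply + gather of row blocks, adjoint = local A^H multiply
# + sum-reduction across ranks. Run with e.g. `mpiexec -n 4 python sketch.py`.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

N_local, K, M = 3, 5, 4                                   # toy sizes (illustrative only)
X = np.random.default_rng(0).standard_normal((K, M))      # replicated operand
A_local = np.random.default_rng(rank).standard_normal((N_local, K))

# Forward: local product, then gather all row blocks of Y on every rank
Y_local = A_local @ X                                     # (N_local, M)
Y = np.vstack(comm.allgather(Y_local))                    # (size * N_local, M)

# Adjoint: each rank contributes A_local^H @ (its row block of Y);
# the contributions are summed across ranks, mirroring the allreduce step
Xadj_partial = A_local.conj().T @ Y_local                 # (K, M)
Xadj = np.zeros_like(Xadj_partial)
comm.Allreduce(Xadj_partial, Xadj, op=MPI.SUM)
```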
@@ -413,7 +412,7 @@ class _MPISummaMatrixMult(MPILinearOperator):
413412 Global number of columns of the matrices representing the input model
414413 and data vectors.
415414 saveAt : :obj:`bool`, optional
416- Save ``A`` and ``A.H`` to speed up the computation of adjoint
415+ Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
417416 (``True``) or create ``A.H`` on-the-fly (``False``).
418417 Note that ``saveAt=True`` will double the amount of required memory.
419418 Default is ``False``.
@@ -451,16 +450,16 @@ class _MPISummaMatrixMult(MPILinearOperator):
451450 This implementation is based on a 2D block distribution across a square process
452451 grid (:math:`\sqrt{P}\times\sqrt{P}`). The matrices are distributed as follows:
453452
454- - The matrix ``A`` is distributed across MPI processes in 2D blocks where
455- each process holds a local block of ``A`` with shape :math:`[N_{loc} \times K_{loc}]`
453+ - The matrix :math:`\mathbf{A}` is distributed across MPI processes in 2D blocks where
454+ each process holds a local block of :math:`\mathbf{A}` with shape :math:`[N_{loc} \times K_{loc}]`
456455 where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`K_{loc} = \frac{K}{\sqrt{P}}`.
457456
458- - The operand matrix ``X`` is also distributed across MPI processes in 2D blocks where
459- each process holds a local block of ``X`` with shape :math:`[K_{loc} \times M_{loc}]`
457+ - The operand matrix :math:`\mathbf{X}` is also distributed across MPI processes in 2D blocks where
458+ each process holds a local block of :math:`\mathbf{X}` with shape :math:`[K_{loc} \times M_{loc}]`
460459 where :math:`K_{loc} = \frac{K}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.
461460
462- - The result matrix ``Y`` is also distributed across MPI processes in 2D blocks where
463- each process holds a local block of ``Y`` with shape :math:`[N_{loc} \times M_{loc}]`
461+ - The result matrix :math:`\mathbf{Y}` is also distributed across MPI processes in 2D blocks where
462+ each process holds a local block of :math:`\mathbf{Y}` with shape :math:`[N_{loc} \times M_{loc}]`
464463 where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.
465464
466465
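As a quick numerical check of the 2D layout above, the sketch below computes the local block shapes of A, X and Y for a given global problem size and process count. The square-grid requirement is stated in the docstring; the ceil-division handling of non-divisible sizes is an assumption of this sketch, not a statement about how the operator pads remainders.

```python
# Local block shapes for the square-grid SUMMA layout sketched above.
import math

def summa_local_shapes(N, K, M, P):
    p = math.isqrt(P)
    assert p * p == P, "SUMMA assumes a square number of processes"
    ceil_div = lambda a, b: -(-a // b)       # assumed remainder handling
    N_loc, K_loc, M_loc = ceil_div(N, p), ceil_div(K, p), ceil_div(M, p)
    return {"A": (N_loc, K_loc), "X": (K_loc, M_loc), "Y": (N_loc, M_loc)}

# e.g. N = K = M = 1000 on P = 4 ranks -> 500 x 500 blocks for A, X and Y
print(summa_local_shapes(1000, 1000, 1000, 4))
```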
@@ -473,18 +472,18 @@ class _MPISummaMatrixMult(MPILinearOperator):
473472
474473 2. **SUMMA Iteration**: For each step ``k`` in the SUMMA algorithm, :math:`k \in [0, \sqrt{P})`:
475474
476- a. **Broadcast A blocks**: Process in column ``k`` broadcasts its ``A``
475+ a. **Broadcast A blocks**: Process in column ``k`` broadcasts its :math:`\mathbf{A}`
477476 block to all other processes in the same process row.
478477
479- b. **Broadcast X blocks**: Process in row ``k`` broadcasts its ``X``
478+ b. **Broadcast X blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{X}`
480479 block to all other processes in the same process column.
481480
482481 c. **Local Computation**: Each process computes the partial matrix
483482 product ``A_broadcast @ X_broadcast`` and accumulates it to its
484483 local result.
485484
486485 3. **Result Assembly**: After all k SUMMA iterations, each process has computed
487- its local block of the result matrix ``Y``.
486+ its local block of the result matrix :math:`\mathbf{Y}`.
488487
489488 **Adjoint Operation (SUMMA Algorithm)**
490489
@@ -496,11 +495,11 @@ class _MPISummaMatrixMult(MPILinearOperator):
496495
497496 2. **SUMMA Adjoint Iteration**: For each step ``k`` in the adjoint SUMMA algorithm:
498497
499- a. **Broadcast A^H blocks**: The conjugate transpose of ``A`` blocks is
498+ a. **Broadcast A^H blocks**: The conjugate transpose of :math:`\mathbf{A}` blocks is
500499 communicated between processes. If ``saveAt=True``, the pre-computed
501500 ``A.H`` is used; otherwise, ``A.T.conj()`` is computed on-the-fly.
502501
503- b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its ``Y``
502+ b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{Y}`
504503 block to all other processes in the same process column.
505504
506505 c. **Local Adjoint Computation**: Each process computes the partial
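The forward SUMMA loop described above (broadcast an A block along each process row, an X block along each process column, accumulate the partial product) can be sketched as follows. This is a hedged sketch under the assumption of equally sized blocks and pre-built row/column sub-communicators; the names `row_comm`, `col_comm` and `p_prime` are illustrative, not the operator's actual attributes. The adjoint sweep follows the same pattern with A^H blocks (pre-stored when ``saveAt=True``) in place of the A blocks.

```python
# One forward SUMMA sweep over a sqrt(P) x sqrt(P) grid (sketch, equal blocks).
import numpy as np
from mpi4py import MPI

def summa_forward(A_local, X_local, row_comm, col_comm, p_prime):
    """A_local: (N_loc, K_loc) block, X_local: (K_loc, M_loc) block."""
    Y_local = np.zeros((A_local.shape[0], X_local.shape[1]), dtype=A_local.dtype)
    for k in range(p_prime):
        # step a: the rank sitting in grid column k broadcasts its A block
        # to the other ranks of its process row
        A_k = A_local.copy() if row_comm.Get_rank() == k else np.empty_like(A_local)
        row_comm.Bcast(A_k, root=k)
        # step b: the rank sitting in grid row k broadcasts its X block
        # to the other ranks of its process column
        X_k = X_local.copy() if col_comm.Get_rank() == k else np.empty_like(X_local)
        col_comm.Bcast(X_k, root=k)
        # step c: accumulate the local partial product
        Y_local += A_k @ X_k
    return Y_local
```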
@@ -683,7 +682,14 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
683682         y[:] = Y_local_unpadded.flatten()
684683 return y
685684
686- class MPIMatrixMult(MPILinearOperator):
685+ def MPIMatrixMult(
686+     A: NDArray,
687+     M: int,
688+     saveAt: bool = False,
689+     base_comm: MPI.Comm = MPI.COMM_WORLD,
690+     kind: Literal["summa", "block"] = "summa",
691+     dtype: DTypeLike = "float64",
692+ ):
687693 r"""
688694 MPI Distributed Matrix Multiplication Operator
689695
@@ -694,32 +700,32 @@ class MPIMatrixMult(MPILinearOperator):
694700
695701 The forward operation computes::
696702
697- Y = A @ X
703+ :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}`
698704
699705 where:
700- - ``A`` is the distributed operator matrix of shape ``[N x K]``
701- - ``X`` is the distributed operand matrix of shape ``[K x M]``
702- - ``Y`` is the resulting distributed matrix of shape ``[N x M]``
706+ - :math:`\mathbf{A}` is the distributed operator matrix of shape :math:`[N \times K]`
707+ - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
708+ - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
703709
704710 The adjoint (conjugate-transpose) operation computes:
711+
712+ :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}`
705713
706- X_adj = A.H @ Y
707-
708- where ``A.H`` is the complex-conjugate transpose of ``A``.
714+ where :math:`\mathbf{A}^H` is the complex-conjugate transpose of :math:`\mathbf{A}`.
709715
710716 Distribution Layouts
711717 --------------------
712718 :summa:
713719 2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`:
714- - ``A`` and ``X`` are partitioned into :math:`[N_{loc} \times K_{loc}]` and
720+ - :math:`\mathbf{A}` and :math:`\mathbf{X}` are partitioned into :math:`[N_{loc} \times K_{loc}]` and
715721 :math:`[K_{loc} \times M_{loc}]` tiles on each rank, respectively.
716- - Each SUMMA iteration broadcasts row- and column-blocks of ``A`` and
717- ``X`` and accumulates local partial products.
722+ - Each SUMMA iteration broadcasts row- and column-blocks of :math:`\mathbf{A}` and
723+ :math:`\mathbf{X}` and accumulates local partial products.
718724
719725 :block:
720- 1D block-row distribution over a 1 x P grid:
721- - ``A`` is partitioned into :math:`[N_{loc} \times K]` blocks across ranks.
722- - ``X`` (and result ``Y``) are partitioned into :math:`[K \times M_{loc}]` blocks.
726+ 1D block-row distribution over a :math:`[1 \times P]` grid:
727+ - :math:`\mathbf{A}` is partitioned into :math:`[N_{loc} \times K]` blocks across ranks.
728+ - :math:`\mathbf{X}` (and result :math:`\mathbf{Y}`) are partitioned into :math:`[K \times M_{loc}]` blocks.
723729 - Local multiplication is followed by row-wise gather (forward) or
724730 allreduce (adjoint) across ranks.
725731
@@ -730,7 +736,7 @@ class MPIMatrixMult(MPILinearOperator):
730736 M : int
731737 Global number of columns in the operand and result matrices.
732738 saveAt : bool, optional
733- If ``True``, store both ``A`` and its conjugate transpose ``A.H``
739+ If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose :math:`\mathbf{A}^H`
734740 to accelerate adjoint operations (uses twice the memory).
735741 Default is ``False``.
736742 base_comm : mpi4py.MPI.Comm, optional
@@ -758,26 +764,9 @@ class MPIMatrixMult(MPILinearOperator):
758764 If the MPI communicator does not form a compatible grid for the
759765 selected algorithm.
760766 """
761-     def __init__(
762-         self,
763-         A: NDArray,
764-         M: int,
765-         saveAt: bool = False,
766-         base_comm: MPI.Comm = MPI.COMM_WORLD,
767-         kind: Literal["summa", "block"] = "summa",
768-         dtype: DTypeLike = "float64",
769-     ):
770-         if kind == "summa":
771-             self._f = _MPISummaMatrixMult(A, M, saveAt, base_comm, dtype)
772-         elif kind == "block":
773-             self._f = _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype)
774-         else:
775-             raise NotImplementedError("kind must be summa or block")
776-         self.kind = kind
777-         super().__init__(shape=self._f.shape, dtype=dtype, base_comm=base_comm)
778-
779-     def _matvec(self, x: DistributedArray) -> DistributedArray:
780-         return self._f.matvec(x)
781-
782-     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
783-         return self._f.rmatvec(x)
767+     if kind == "summa":
768+         return _MPISummaMatrixMult(A, M, saveAt, base_comm, dtype)
769+     elif kind == "block":
770+         return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype)
771+     else:
772+         raise NotImplementedError("kind must be summa or block")
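For completeness, a hedged usage sketch of the new factory function. Only the call signature shown in the diff is taken as given; the import path and the DistributedArray set-up are assumptions that depend on how the package exports the operator and on the chosen ``kind``.

```python
# Usage sketch of the MPIMatrixMult factory (kind dispatches to SUMMA or block).
import numpy as np
from mpi4py import MPI
from pylops_mpi.basicoperators import MPIMatrixMult   # import path assumed

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# each rank supplies its local block of A; M is the global number of columns of X
A_local = np.random.default_rng(rank).standard_normal((4, 6))
M = 8

Op = MPIMatrixMult(A_local, M, saveAt=True, base_comm=comm, kind="summa")
# y = Op @ x        # x: DistributedArray holding this rank's block of X (schematic)
# xadj = Op.H @ y   # adjoint uses A^H, pre-stored because saveAt=True
```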