import math
import numpy as np
- from typing import Tuple
+ from typing import Tuple, Union, Literal
from mpi4py import MPI

from pylops.utils.backend import get_module
@@ -196,8 +196,8 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu



- class MPIMatrixMult(MPILinearOperator):
-     r"""MPI Matrix multiplication
+ class _MPIBlockMatrixMult(MPILinearOperator):
+     r"""MPI Blocked Matrix multiplication

    Implement distributed matrix-matrix multiplication between a matrix
    :math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored
@@ -395,7 +395,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        y[:] = y_layer.flatten()
        return y

- class MPISummaMatrixMult(MPILinearOperator):
+ class _MPISummaMatrixMult(MPILinearOperator):
    r"""MPI SUMMA Matrix multiplication

    Implements distributed matrix-matrix multiplication using the SUMMA algorithm
@@ -681,4 +681,103 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:

        Y_local_unpadded = Y_local[:local_k, :local_m]
        y[:] = Y_local_unpadded.flatten()
-         return y
+         return y
+
+ class MPIMatrixMult(MPILinearOperator):
+     r"""
+     MPI Distributed Matrix Multiplication Operator
+
+     This general operator performs distributed matrix-matrix multiplication
+     using either the SUMMA (Scalable Universal Matrix Multiplication Algorithm)
+     or a 1D block-row decomposition algorithm, depending on the specified
+     ``kind`` parameter.
+
+     The forward operation computes::
+
+         Y = A @ X
+
+     where:
+     - ``A`` is the distributed operator matrix of shape ``[N x K]``
+     - ``X`` is the distributed operand matrix of shape ``[K x M]``
+     - ``Y`` is the resulting distributed matrix of shape ``[N x M]``
+
+     The adjoint (conjugate-transpose) operation computes::
+
+         X_adj = A.H @ Y
+
+     where ``A.H`` is the complex-conjugate transpose of ``A``.
+
+     Distribution Layouts
+     --------------------
+     :summa:
+         2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`:
+         - ``A`` and ``X`` are partitioned into :math:`[N_{loc} \times K_{loc}]` and
+           :math:`[K_{loc} \times M_{loc}]` tiles on each rank, respectively.
+         - Each SUMMA iteration broadcasts row- and column-blocks of ``A`` and
+           ``X`` and accumulates local partial products.
+
+     :block:
+         1D block-row distribution over a :math:`1 \times P` grid:
+         - ``A`` is partitioned into :math:`[N_{loc} \times K]` blocks across ranks.
+         - ``X`` (and result ``Y``) are partitioned into :math:`[K \times M_{loc}]` blocks.
+         - Local multiplication is followed by row-wise gather (forward) or
+           allreduce (adjoint) across ranks.
+
+     Parameters
+     ----------
+     A : NDArray
+         Local block of the matrix operator.
+     M : int
+         Global number of columns in the operand and result matrices.
+     saveAt : bool, optional
+         If ``True``, store both ``A`` and its conjugate transpose ``A.H``
+         to accelerate adjoint operations (uses twice the memory).
+         Default is ``False``.
+     base_comm : mpi4py.MPI.Comm, optional
+         MPI communicator to use. Defaults to ``MPI.COMM_WORLD``.
+     kind : {'summa', 'block'}, optional
+         Algorithm to use: ``'summa'`` for the SUMMA 2D algorithm, or
+         ``'block'`` for the 1D block-row decomposition. Default is ``'summa'``.
+     dtype : DTypeLike, optional
+         Numeric data type for computations. Default is ``np.float64``.
+
+     Attributes
+     ----------
+     shape : :obj:`tuple`
+         Operator shape.
+     comm : mpi4py.MPI.Comm
+         The MPI communicator in use.
+     kind : str
+         Selected distributed matrix multiply algorithm (``'summa'`` or ``'block'``).
+
+     Raises
+     ------
+     NotImplementedError
+         If ``kind`` is not one of ``'summa'`` or ``'block'``.
+     Exception
+         If the MPI communicator does not form a compatible grid for the
+         selected algorithm.
+     """
+     def __init__(
+         self,
+         A: NDArray,
+         M: int,
+         saveAt: bool = False,
+         base_comm: MPI.Comm = MPI.COMM_WORLD,
+         kind: Literal["summa", "block"] = "summa",
+         dtype: DTypeLike = "float64",
+     ):
+         if kind == "summa":
+             self._f = _MPISummaMatrixMult(A, M, saveAt, base_comm, dtype)
+         elif kind == "block":
+             self._f = _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype)
+         else:
+             raise NotImplementedError("kind must be summa or block")
+         self.kind = kind
+         super().__init__(shape=self._f.shape, dtype=dtype, base_comm=base_comm)
+
+     def _matvec(self, x: DistributedArray) -> DistributedArray:
+         return self._f.matvec(x)
+
+     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
+         return self._f.rmatvec(x)
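
A minimal sketch (not part of the diff above) of the 1D ``block`` layout described in the docstring: rows of ``A`` and columns of ``X`` are split as evenly as possible across ``P`` ranks, giving each rank an ``[N_loc x K]`` block of ``A`` and a ``[K x M_loc]`` block of ``X``. The helper name and the choice of giving the first ranks the extra rows/columns are illustrative assumptions, not the operator's actual partitioning code.

```python
import numpy as np
from mpi4py import MPI

def block_bounds(n_global: int, rank: int, nprocs: int) -> slice:
    """Hypothetical helper: portion of a global axis owned by `rank`
    (the first `n_global % nprocs` ranks receive one extra element)."""
    counts = np.full(nprocs, n_global // nprocs, dtype=int)
    counts[: n_global % nprocs] += 1
    offsets = np.concatenate(([0], np.cumsum(counts)))
    return slice(offsets[rank], offsets[rank + 1])

comm = MPI.COMM_WORLD
rank, nprocs = comm.Get_rank(), comm.Get_size()

N, K, M = 12, 8, 6                                    # global sizes
A_global = np.arange(N * K, dtype=np.float64).reshape(N, K)
X_global = np.ones((K, M))

A_local = A_global[block_bounds(N, rank, nprocs), :]   # [N_loc x K] block of A
X_local = X_global[:, block_bounds(M, rank, nprocs)]   # [K x M_loc] block of X
```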
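
And a hedged usage sketch of the new ``MPIMatrixMult`` dispatcher itself, run under ``mpirun``. How the operand is wrapped into a ``pylops_mpi.DistributedArray`` (import paths, the ``local_shapes`` argument, the flattening order each ``kind`` expects) is an assumption about the surrounding API; the operator's own tests and examples are the authoritative reference.

```python
# mpirun -n 4 python example.py  -- illustrative only
import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedArray                 # assumed import path
# from pylops_mpi.basicoperators import MPIMatrixMult   # assumed import path

comm = MPI.COMM_WORLD
rank, nprocs = comm.Get_rank(), comm.Get_size()

N, K, M = 12, 8, 6
A_global = np.random.default_rng(0).standard_normal((N, K))
X_global = np.ones((K, M))

# 1D "block" layout: each rank owns a row-block of A and a column-block of X
A_local = np.array_split(A_global, nprocs, axis=0)[rank]   # [N_loc x K]
X_local = np.array_split(X_global, nprocs, axis=1)[rank]   # [K x M_loc]

Aop = MPIMatrixMult(A_local, M=M, kind="block", base_comm=comm, dtype="float64")

# Wrap the flattened local operand block; the local_shapes handling below is
# an assumption about the DistributedArray constructor.
x = DistributedArray(
    global_shape=K * M,
    local_shapes=comm.allgather((X_local.size,)),
    base_comm=comm,
    dtype=np.float64,
)
x[:] = X_local.ravel()

y = Aop.matvec(x)       # distributed Y = A @ X
xadj = Aop.rmatvec(y)   # distributed A.H @ Y
```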