Commit 66e3296
Added Generic MatMulOp with docstring
1 parent 9d36b0c

File tree: 1 file changed (+104, -5 lines changed)

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 104 additions & 5 deletions
@@ -1,6 +1,6 @@
 import math
 import numpy as np
-from typing import Tuple
+from typing import Tuple, Union, Literal
 from mpi4py import MPI
 
 from pylops.utils.backend import get_module
@@ -196,8 +196,8 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu
 
 
 
-class MPIMatrixMult(MPILinearOperator):
-    r"""MPI Matrix multiplication
+class _MPIBlockMatrixMult(MPILinearOperator):
+    r"""MPI Blocked Matrix multiplication
 
     Implement distributed matrix-matrix multiplication between a matrix
     :math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored
@@ -395,7 +395,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         y[:] = y_layer.flatten()
         return y
 
-class MPISummaMatrixMult(MPILinearOperator):
+class _MPISummaMatrixMult(MPILinearOperator):
     r"""MPI SUMMA Matrix multiplication
 
     Implements distributed matrix-matrix multiplication using the SUMMA algorithm
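
For context, the SUMMA pattern named here can be illustrated with a short, standalone mpi4py/NumPy sketch. This is not the commit's implementation: the square process grid, the fixed tile size nb, and all variable names below are illustrative assumptions.

# Run with a square number of ranks, e.g. mpirun -n 4 python summa_sketch.py
import math

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
p = math.isqrt(size)
assert p * p == size, "this sketch assumes a square number of ranks"

# 2D process grid: ranks in the same grid row share a row communicator,
# ranks in the same grid column share a column communicator
row, col = divmod(rank, p)
row_comm = comm.Split(color=row, key=col)
col_comm = comm.Split(color=col, key=row)

nb = 2                                   # illustrative square tile size
rng = np.random.default_rng(rank)
A_loc = rng.standard_normal((nb, nb))    # tile of A owned by grid position (row, col)
X_loc = rng.standard_normal((nb, nb))    # tile of X owned by grid position (row, col)
Y_loc = np.zeros((nb, nb))

# Step k: the owner of the k-th block-column of A broadcasts its tile along the row,
# the owner of the k-th block-row of X broadcasts its tile along the column,
# and every rank accumulates the local partial product.
for k in range(p):
    A_k = A_loc.copy() if col == k else np.empty_like(A_loc)
    row_comm.Bcast(A_k, root=k)
    X_k = X_loc.copy() if row == k else np.empty_like(X_loc)
    col_comm.Bcast(X_k, root=k)
    Y_loc += A_k @ X_k                   # (row, col) tile of Y after the final step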
@@ -681,4 +681,103 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
 
         Y_local_unpadded = Y_local[:local_k, :local_m]
         y[:] = Y_local_unpadded.flatten()
-        return y
+        return y
+
+class MPIMatrixMult(MPILinearOperator):
+    r"""
+    MPI Distributed Matrix Multiplication Operator
+
+    This general operator performs distributed matrix-matrix multiplication
+    using either the SUMMA (Scalable Universal Matrix Multiplication Algorithm)
+    or a 1D block-row decomposition algorithm, depending on the specified
+    ``kind`` parameter.
+
+    The forward operation computes::
+
+        Y = A @ X
+
+    where:
+    - ``A`` is the distributed operator matrix of shape ``[N x K]``
+    - ``X`` is the distributed operand matrix of shape ``[K x M]``
+    - ``Y`` is the resulting distributed matrix of shape ``[N x M]``
+
+    The adjoint (conjugate-transpose) operation computes::
+
+        X_adj = A.H @ Y
+
+    where ``A.H`` is the complex-conjugate transpose of ``A``.
+
+    Distribution Layouts
+    --------------------
+    :summa:
+        2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`:
+        - ``A`` and ``X`` are partitioned into :math:`[N_loc \times K_loc]` and
+          :math:`[K_loc \times M_loc]` tiles on each rank, respectively.
+        - Each SUMMA iteration broadcasts row- and column-blocks of ``A`` and
+          ``X`` and accumulates local partial products.
+
+    :block:
+        1D block-row distribution over a 1 x P grid:
+        - ``A`` is partitioned into :math:`[N_loc \times K]` blocks across ranks.
+        - ``X`` (and result ``Y``) are partitioned into :math:`[K \times M_loc]` blocks.
+        - Local multiplication is followed by row-wise gather (forward) or
+          allreduce (adjoint) across ranks.
+
+    Parameters
+    ----------
+    A : NDArray
+        Local block of the matrix operator.
+    M : int
+        Global number of columns in the operand and result matrices.
+    saveAt : bool, optional
+        If ``True``, store both ``A`` and its conjugate transpose ``A.H``
+        to accelerate adjoint operations (uses twice the memory).
+        Default is ``False``.
+    base_comm : mpi4py.MPI.Comm, optional
+        MPI communicator to use. Defaults to ``MPI.COMM_WORLD``.
+    kind : {'summa', 'block'}, optional
+        Algorithm to use: ``'summa'`` for the SUMMA 2D algorithm, or
+        ``'block'`` for the block-row-col decomposition. Default is ``'summa'``.
+    dtype : DTypeLike, optional
+        Numeric data type for computations. Default is ``np.float64``.
+
+    Attributes
+    ----------
+    shape : :obj:`tuple`
+        Operator shape
+    comm : mpi4py.MPI.Comm
+        The MPI communicator in use.
+    kind : str
+        Selected distributed matrix multiply algorithm ('summa' or 'block').
+
+    Raises
+    ------
+    NotImplementedError
+        If ``kind`` is not one of ``'summa'`` or ``'block'``.
+    Exception
+        If the MPI communicator does not form a compatible grid for the
+        selected algorithm.
+    """
+    def __init__(
+        self,
+        A: NDArray,
+        M: int,
+        saveAt: bool = False,
+        base_comm: MPI.Comm = MPI.COMM_WORLD,
+        kind: Literal["summa", "block"] = "summa",
+        dtype: DTypeLike = "float64",
+    ):
+        if kind == "summa":
+            self._f = _MPISummaMatrixMult(A, M, saveAt, base_comm, dtype)
+        elif kind == "block":
+            self._f = _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype)
+        else:
+            raise NotImplementedError("kind must be summa or block")
+        self.kind = kind
+        super().__init__(shape=self._f.shape, dtype=dtype, base_comm=base_comm)
+
+    def _matvec(self, x: DistributedArray) -> DistributedArray:
+        return self._f.matvec(x)
+
+    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
+        return self._f.rmatvec(x)
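
As a serial reference for what the wrapper computes, the forward/adjoint pair from the new docstring can be written down with plain NumPy; a distributed run with either kind should reproduce these results up to how the matrices are partitioned across ranks. The shapes and tolerance below are illustrative.

import numpy as np

rng = np.random.default_rng(0)
N, K, M = 8, 6, 4                      # global shapes from the docstring
A = rng.standard_normal((N, K)) + 1j * rng.standard_normal((N, K))
X = rng.standard_normal((K, M)) + 1j * rng.standard_normal((K, M))

Y = A @ X                              # forward:  Y = A @ X        -> [N x M]
X_adj = A.conj().T @ Y                 # adjoint:  X_adj = A.H @ Y  -> [K x M]

# Dot test: <A @ X, V> == <X, A.H @ V> for any V of shape [N x M]
V = rng.standard_normal((N, M)) + 1j * rng.standard_normal((N, M))
lhs = np.vdot(A @ X, V)                # vdot conjugates its first argument
rhs = np.vdot(X, A.conj().T @ V)
assert np.isclose(lhs, rhs)

Design-wise, the new MPIMatrixMult is a thin dispatcher: __init__ builds either _MPISummaMatrixMult or _MPIBlockMatrixMult according to kind, and _matvec/_rmatvec simply forward to the chosen operator, so callers can switch algorithms by changing only the kind argument.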

0 commit comments
