Merged
Changes from 14 commits (30 commits in total):
7fcc2cf
Added impl, test and example
astroC86 Jun 2, 2025
6a9d382
Merge branch 'PyLops:main' into astroC86-SUMMA
astroC86 Jun 2, 2025
f72fce6
Addressed some comments
astroC86 Jun 10, 2025
c607283
Example formating
astroC86 Jun 10, 2025
de1a173
Rename MatrixMultiply file to MatrixMult
astroC86 Jun 10, 2025
82b7e34
Addressed more issues
astroC86 Jun 11, 2025
9e1a49f
Addressed comments
astroC86 Jun 13, 2025
22cde7b
Addressing changes
astroC86 Jun 13, 2025
740030d
Minor cosmetic changes
astroC86 Jun 13, 2025
a88dec3
More minor changes
astroC86 Jun 13, 2025
66f1770
Example shape dims general
astroC86 Jun 13, 2025
7ac593d
Added comments to example
astroC86 Jun 13, 2025
8a56096
I donot know why I thought I needed to batch
astroC86 Jun 13, 2025
42452a1
Inital docstring for matrix mult
astroC86 Jun 13, 2025
a110ff8
minor: cleanup of docstrings and updated example
mrava87 Jun 16, 2025
bd9ad37
minor: fix mistake in plot_matrixmult
mrava87 Jun 16, 2025
18db078
removed now useless bcast and fixed mask in test
astroC86 Jun 17, 2025
ef3c283
changed tests
astroC86 Jun 17, 2025
4e39068
Fixed tests and moved checks to root
astroC86 Jun 17, 2025
ed3b585
Fix internal check for MPIMatrixMult
astroC86 Jun 17, 2025
7b76f96
Fixed Notation
astroC86 Jun 17, 2025
3e9659e
Skipping test if number of procs is not square for now
astroC86 Jun 17, 2025
dd9b43c
Merge branch 'main' into astroC86-SUMMA
astroC86 Jun 18, 2025
a85e75a
Fixed Doc error
astroC86 Jun 26, 2025
b7e6702
Renamed layer and group as to row and col respectively
astroC86 Jun 27, 2025
ae5661b
minor: small improvements to text
mrava87 Jun 29, 2025
053e52d
minor: fix flake8
mrava87 Jun 29, 2025
9aedd7c
MatrixMul works with non-square prcs by creating square subcommunicator
astroC86 Jun 30, 2025
4c662d6
minor: stylistic fixes
mrava87 Jul 1, 2025
0c34b78
minor: fix flake8
mrava87 Jul 1, 2025
172 changes: 172 additions & 0 deletions examples/plot_matrixmult.py
@@ -0,0 +1,172 @@
"""
Distributed Matrix Multiplication
=================================
This example shows how to use the :py:class:`pylops_mpi.basicoperators.MatrixMult.MPIMatrixMult`
operator to perform a distributed matrix-matrix multiplication, with the operator
matrix distributed over block rows and the operand matrix over block columns of a
square process grid.
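
The example is designed to be run with multiple MPI processes, for instance
(assuming a standard MPI installation) via ``mpiexec -n 4 python plot_matrixmult.py``.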
"""
from matplotlib import pyplot as plt
import math
import numpy as np
from mpi4py import MPI

from pylops_mpi import DistributedArray, Partition
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

plt.close("all")
###############################################################################
# We set the seed such that all processes create the same initial matrices.
# In practice, this data would be loaded (or generated) on each rank in a manner
# appropriate to the use-case.
np.random.seed(42)

# MPI parameters
comm = MPI.COMM_WORLD
rank = comm.Get_rank() # rank of current process
size = comm.Get_size() # number of processes

p_prime = int(math.ceil(math.sqrt(size)))
C = int(math.ceil(size / p_prime))

if (p_prime * C) != size:
    print("No. of procs has to be a square number")
    exit(-1)

# matrix dims
M, K, N = 4, 4, 4
A = np.random.rand(M * K).astype(dtype=np.float32).reshape(M, K)
X = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N)
################################################################################
#Process Grid Organization
#*************************
#
#The processes are arranged in a :math:`\sqrt{P} \times \sqrt{P}` grid, where :math:`P` is the total number of processes.
#
#Define
#
#.. math::
# P' = \bigl \lceil \sqrt{P} \bigr \rceil
#
#and the replication factor
#
#.. math::
# C = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
#
#Each process is assigned a pair of coordinates :math:`(g, l)` within this grid:
#
#.. math::
# g = \mathrm{rank} \bmod P',
# \quad
# l = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor.
#
#For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
#
#.. raw:: html
#
# <div style="text-align: center; font-family: monospace; white-space: pre;">
# ┌────────────┬────────────┐
# │ Rank 0 │ Rank 1 │
# │ (g=0, l=0) │ (g=1, l=0) │
# ├────────────┼────────────┤
# │ Rank 2 │ Rank 3 │
# │ (g=0, l=1) │ (g=1, l=1) │
# └────────────┴────────────┘
# </div>
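
###############################################################################
# As a purely illustrative sketch (assuming a hypothetical ``P = 4``, independent
# of the number of ranks this script is actually run with), the table above can
# be reproduced with the same formulas used below; ``P_demo`` and
# ``p_prime_demo`` are illustrative names only.

if rank == 0:
    P_demo = 4
    p_prime_demo = int(math.ceil(math.sqrt(P_demo)))
    for r in range(P_demo):
        print(f"rank {r} -> (g={r % p_prime_demo}, l={r // p_prime_demo})")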

my_group = rank % p_prime
my_layer = rank // p_prime

# Create the sub‐communicators
layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer
group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group

blk_rows = int(math.ceil(M / p_prime))
blk_cols = int(math.ceil(N / p_prime))

rs = my_group * blk_rows
re = min(M, rs + blk_rows)
my_own_rows = re - rs

cs = my_layer * blk_cols
ce = min(N, cs + blk_cols)
my_own_cols = ce - cs

################################################################################
# Each rank will end up with:
#
# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}`
# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}`
#
# as follows:
A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()

################################################################################
#.. raw:: html
#
# <div style="text-align: left; font-family: monospace; white-space: pre;">
# <b>Matrix A (4 x 4):</b>
# ┌─────────────────┐
# │ a11 a12 a13 a14 │ <- Rows 0–1 (Group 0)
# │ a21 a22 a23 a24 │
# ├─────────────────┤
# │ a31 a32 a33 a34 │ <- Rows 2–3 (Group 1)
# │ a41 a42 a43 a44 │
# └─────────────────┘
# </div>
#
#.. raw:: html
#
# <div style="text-align: left; font-family: monospace; white-space: pre;">
# <b>Matrix X (4 x 4):</b>
# ┌─────────┬─────────┐
# │ x11 x12 │ x13 x14 │ <- Cols 0–1 (Layer 0), Cols 2–3 (Layer 1)
# │ x21 x22 │ x23 x24 │
# │ x31 x32 │ x33 x34 │
# │ x41 x42 │ x43 x44 │
# └─────────┴─────────┘
#
# </div>
#

################################################################################
#Forward Operation
#*****************
# To perform the distributed matrix-matrix multiplication :math:`Y = \text{Aop} \times X`,
# we first create the distributed operator :math:`\text{Aop}` and the distributed
# operand :math:`X` from the local blocks :math:`A_p` and :math:`X_p`, respectively.
Aop = MPIMatrixMult(A_p, N, dtype="float32")
################################################################################
# The operand is stored in a :py:class:`pylops_mpi.DistributedArray`, whose local
# shapes and mask must be consistent with the block-column distribution of :math:`X`.
col_lens = comm.allgather(my_own_cols)
total_cols = np.sum(col_lens)
x = DistributedArray(global_shape=K * total_cols,
                     local_shapes=[K * col_len for col_len in col_lens],
                     partition=Partition.SCATTER,
                     mask=[i // p_prime for i in range(comm.Get_size())],
                     base_comm=comm,
                     dtype="float32")
x[:] = X_p.flatten()
################################################################################
# Applying the forward operator yields a distributed :math:`Y`, partitioned across
# ranks in the same way as :math:`X`.
y = Aop @ x
###############################################################################
#Adjoint Operation
#*****************
# In a similar fashion, we apply the adjoint operation :math:`X_{adj} = A^H Y`
xadj = Aop.H @ y
###############################################################################
# Finally, we verify the result against the equivalent serial computation. Each
# rank checks that it has computed the correct values for its partition.
y_loc = A @ X
xadj_loc = (A.T.dot(y_loc.conj())).conj()

expected_y_loc = y_loc[:, cs:ce].flatten().astype(np.float32)
expected_xadj_loc = xadj_loc[:, cs:ce].flatten().astype(np.float32)

if not np.allclose(y.local_array, expected_y_loc, rtol=1e-6):
    print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
    print(f'{rank} local: {y.local_array}, expected: {y_loc[:, cs:ce]}')
else:
    print(f"RANK {rank}: FORWARD VERIFICATION PASSED")

if not np.allclose(xadj.local_array, expected_xadj_loc, rtol=1e-6):
    print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
    print(f'{rank} local: {xadj.local_array}, expected: {xadj_loc[:, cs:ce]}')
else:
    print(f"RANK {rank}: ADJOINT VERIFICATION PASSED")

206 changes: 206 additions & 0 deletions pylops_mpi/basicoperators/MatrixMult.py
@@ -0,0 +1,206 @@
import numpy as np
import math
from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray

from pylops_mpi import (
    DistributedArray,
    MPILinearOperator,
    Partition
)


class MPIMatrixMult(MPILinearOperator):
r"""
Distributed Matrix-Matrix multiplication
Implements a distributed matrix-matrix multiplication

Parameters
----------
A : :obj:`numpy.ndarray`
Local block of the matrix multiplication operator of shape ``(M_loc, K)``
where ``M_loc`` is the number of rows stored on this MPI rank and
``K`` is the global number of columns.
N : :obj:`int`
Global leading dimension of the operand matrix (number of columns).
saveAt : :obj:`bool`, optional
Save ``A`` and ``A.H`` to speed up the computation of adjoint
(``True``) or create ``A.H`` on-the-fly (``False``)
Note that ``saveAt=True`` will double the amount of required memory.
The default is ``False``.
base_comm : :obj:`mpi4py.MPI.Comm`, optional
MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
dtype : :obj:`str`, optional
Type of elements in input array.

Attributes
----------
shape : :obj:`tuple`
Operator shape

Raises
------
Exception
If the operator is created without a square number of mpi ranks.
ValueError
If input vector does not have the correct partition type.

    Notes
    -----
    This implementation uses a 1D block distribution of the operand matrix and
    of the operator, replicated across the :math:`P` processes by a factor
    equivalent to :math:`\sqrt{P}`, over a square process grid
    (:math:`\sqrt{P}\times\sqrt{P}`).

    The operator implements a distributed matrix-matrix multiplication where:

    - The matrix ``A`` is distributed across MPI processes in a block-row fashion
    - Each process holds a local block of ``A`` with shape ``(M_loc, K)``
    - The operand matrix ``X`` is distributed in a block-column fashion
    - Communication is minimized by using a 2D process grid layout

    The forward operation computes :math:`Y = A \cdot X` where:

    - :math:`A` is the distributed matrix operator of shape ``(M, K)``
    - :math:`X` is the distributed operand matrix of shape ``(K, N)``
    - :math:`Y` is the resulting distributed matrix of shape ``(M, N)``

    The adjoint operation computes :math:`Y = A^H \cdot X`, where :math:`A^H`
    is the conjugate transpose of :math:`A`.

    **Steps for the forward operation** (:math:`Y = A \cdot X`):

    1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
       of shape ``(K, N)``) is reshaped to ``(K, N_local)``, where ``N_local``
       is the number of columns assigned to the current process.

    2. **Data Broadcasting**: Within each layer (processes with the same
       ``layer_id``), the operand data is broadcast from the process whose
       ``group_id`` matches the ``layer_id``. This ensures all processes in a
       layer have access to the same operand columns.

    3. **Local Computation**: Each process computes ``A_local @ X_local`` where:

       - ``A_local`` is the local block of matrix ``A`` (shape ``M_local × K``)
       - ``X_local`` is the broadcasted operand (shape ``K × N_local``)

    4. **Layer Gather**: Results from all processes in each layer are gathered
       using ``allgather`` to reconstruct the full result matrix vertically.

    **Steps for the adjoint operation** (:math:`Y = A^H \cdot X`)

    The adjoint operation performs the conjugate-transpose multiplication:

    1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(M, N_local)``,
       representing the local columns of the input matrix.

    2. **Local Adjoint Computation**: Each process computes ``A_local.H @ X_tile``,
       where ``A_local.H`` is either the pre-computed ``At`` (if ``saveAt=True``)
       or computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``).
       Each process multiplies its conjugate-transposed local ``A`` block
       ``A_local^H`` (shape ``K × M_block``) with the extracted ``X_tile``
       (shape ``M_block × N_local``), producing a partial result of shape
       ``(K, N_local)``. This is the local contribution of the corresponding
       columns of ``A^H`` to the final result.

    3. **Layer Reduction**: Since the full result :math:`Y = A^H \cdot X` is the
       sum of contributions from all column blocks of ``A^H``, processes in the
       same layer perform an ``allreduce`` sum to combine their partial results,
       giving the complete ``(K, N_local)`` result for their assigned columns.
    """
    def __init__(
        self,
        A: NDArray,
        N: int,
        saveAt: bool = False,
        base_comm: MPI.Comm = MPI.COMM_WORLD,
        dtype: DTypeLike = "float64",
    ) -> None:
        rank = base_comm.Get_rank()
        size = base_comm.Get_size()

        # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
        self._P_prime = int(math.ceil(math.sqrt(size)))
        self._C = int(math.ceil(size / self._P_prime))
        if self._P_prime * self._C != size:
            raise Exception("Number of Procs must be a square number")

        # Compute this process's group and layer indices
        self._group_id = rank % self._P_prime
        self._layer_id = rank // self._P_prime

        # Split communicators by layer (rows) and by group (columns)
        self.base_comm = base_comm
        self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
        self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)

        self.A = A.astype(np.dtype(dtype))
        if saveAt:
            self.At = A.T.conj()

        self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
        self.K = A.shape[1]
        self.N = N

        block_cols = int(math.ceil(self.N / self._P_prime))
        blk_rows = int(math.ceil(self.M / self._P_prime))

        self._row_start = self._group_id * blk_rows
        self._row_end = min(self.M, self._row_start + blk_rows)

        self._col_start = self._layer_id * block_cols
        self._col_end = min(self.N, self._col_start + block_cols)

        self._local_ncols = self._col_end - self._col_start
        self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
        total_ncols = np.sum(self._rank_col_lens)

        self.dims = (self.K, total_ncols)
        self.dimsd = (self.M, total_ncols)
        shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")

        y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
                             local_shapes=[(self.M * c) for c in self._rank_col_lens],
                             mask=x.mask,
                             partition=Partition.SCATTER,
                             dtype=self.dtype)

        my_own_cols = self._rank_col_lens[self.rank]
        x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
        x_arr = x_arr.astype(self.dtype)

        X_local = self._layer_comm.bcast(x_arr if self._group_id == self._layer_id else None, root=self._layer_id)
        Y_local = ncp.vstack(
            self._layer_comm.allgather(
                ncp.matmul(self.A, X_local)
            )
        )
        y[:] = Y_local.flatten()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")

        y = DistributedArray(
            global_shape=(self.K * self.dimsd[1]),
            local_shapes=[self.K * c for c in self._rank_col_lens],
            mask=x.mask,
            partition=Partition.SCATTER,
            dtype=self.dtype,
        )

        x_arr = x.local_array.reshape((self.M, self._local_ncols)).astype(self.dtype)
        X_tile = x_arr[self._row_start:self._row_end, :]
        A_local = self.At if hasattr(self, "At") else self.A.T.conj()
        Y_local = ncp.matmul(A_local, X_tile)
        y_layer = self._layer_comm.allreduce(Y_local, op=MPI.SUM)
        y[:] = y_layer.flatten()
        return y
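
The communication pattern described in the ``Notes`` above can also be reproduced
with a minimal standalone sketch that uses only ``mpi4py`` and NumPy. It assumes a
square number of ranks (e.g. run with ``mpeg`` replaced by ``mpiexec -n 4``) and
illustrative variable names mirroring the example above (``p_prime``, ``group_id``,
``layer_id``); it is a sketch of the pattern, not the operator itself.

import math
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
p_prime = int(math.ceil(math.sqrt(size)))  # square process grid assumed

group_id, layer_id = rank % p_prime, rank // p_prime
layer_comm = comm.Split(color=layer_id, key=group_id)

M = K = N = 4
np.random.seed(42)
A = np.random.rand(M, K).astype(np.float32)
X = np.random.rand(K, N).astype(np.float32)

# block-row of A and block-column of X owned by this rank
blk_rows = int(math.ceil(M / p_prime))
blk_cols = int(math.ceil(N / p_prime))
rs, re = group_id * blk_rows, min(M, (group_id + 1) * blk_rows)
cs, ce = layer_id * blk_cols, min(N, (layer_id + 1) * blk_cols)
A_local, X_local = A[rs:re, :], X[:, cs:ce]

# forward: broadcast the operand block within the layer from the "diagonal"
# process (group_id == layer_id), multiply locally, then stack the row blocks
X_bcast = layer_comm.bcast(X_local if group_id == layer_id else None, root=layer_id)
Y_local = np.vstack(layer_comm.allgather(A_local @ X_bcast))  # shape (M, N_local)

# adjoint: apply A_local^H to this rank's row slice of Y and sum the partial
# (K, N_local) contributions of all processes in the same layer
Xadj_local = layer_comm.allreduce(A_local.conj().T @ Y_local[rs:re, :], op=MPI.SUM)

# serial check of both results for the columns owned by this layer
assert np.allclose(Y_local, (A @ X)[:, cs:ce])
assert np.allclose(Xadj_local, (A.conj().T @ A @ X)[:, cs:ce])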