Merged

30 commits
7fcc2cf
Added impl, test and example
astroC86 Jun 2, 2025
6a9d382
Merge branch 'PyLops:main' into astroC86-SUMMA
astroC86 Jun 2, 2025
f72fce6
Addressed some comments
astroC86 Jun 10, 2025
c607283
Example formatting
astroC86 Jun 10, 2025
de1a173
Rename MatrixMultiply file to MatrixMult
astroC86 Jun 10, 2025
82b7e34
Addressed more issues
astroC86 Jun 11, 2025
9e1a49f
Addressed comments
astroC86 Jun 13, 2025
22cde7b
Addressing changes
astroC86 Jun 13, 2025
740030d
Minor cosmetic changes
astroC86 Jun 13, 2025
a88dec3
More minor changes
astroC86 Jun 13, 2025
66f1770
Example shape dims general
astroC86 Jun 13, 2025
7ac593d
Added comments to example
astroC86 Jun 13, 2025
8a56096
I do not know why I thought I needed to batch
astroC86 Jun 13, 2025
42452a1
Initial docstring for matrix mult
astroC86 Jun 13, 2025
a110ff8
minor: cleanup of docstrings and updated example
mrava87 Jun 16, 2025
bd9ad37
minor: fix mistake in plot_matrixmult
mrava87 Jun 16, 2025
18db078
removed now useless bcast and fixed mask in test
astroC86 Jun 17, 2025
ef3c283
changed tests
astroC86 Jun 17, 2025
4e39068
Fixed tests and moved checks to root
astroC86 Jun 17, 2025
ed3b585
Fix internal check for MPIMatrixMult
astroC86 Jun 17, 2025
7b76f96
Fixed Notation
astroC86 Jun 17, 2025
3e9659e
Skipping test if number of procs is not square for now
astroC86 Jun 17, 2025
dd9b43c
Merge branch 'main' into astroC86-SUMMA
astroC86 Jun 18, 2025
a85e75a
Fixed Doc error
astroC86 Jun 26, 2025
b7e6702
Renamed layer and group to row and col respectively
astroC86 Jun 27, 2025
ae5661b
minor: small improvements to text
mrava87 Jun 29, 2025
053e52d
minor: fix flake8
mrava87 Jun 29, 2025
9aedd7c
MatrixMult works with non-square procs by creating square subcommunicator
astroC86 Jun 30, 2025
4c662d6
minor: stylistic fixes
mrava87 Jul 1, 2025
0c34b78
minor: fix flake8
mrava87 Jul 1, 2025
104 changes: 104 additions & 0 deletions examples/matrixmul.py
@@ -0,0 +1,104 @@
import math

import numpy as np
from mpi4py import MPI

from pylops_mpi import DistributedArray, Partition
from pylops_mpi.basicoperators.MatrixMultiply import SUMMAMatrixMult

np.random.seed(42)

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nProcs = comm.Get_size()


P_prime = int(math.ceil(math.sqrt(nProcs)))
C = int(math.ceil(nProcs / P_prime))
assert P_prime * C >= nProcs

# matrix dims
M = 32 # any M
K = 32 # any K
N = 35 # any N

blk_rows = int(math.ceil(M / P_prime))
blk_cols = int(math.ceil(N / P_prime))

my_group = rank % P_prime
my_layer = rank // P_prime

# sub-communicators
layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer
group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group

# Each rank will end up with:
# A_p: shape (my_own_rows, K)
# B_p: shape (K, my_own_cols)
# where
row_start = my_group * blk_rows
row_end = min(M, row_start + blk_rows)
my_own_rows = row_end - row_start

col_start = my_group * blk_cols # note: same my_group index on cols
col_end = min(N, col_start + blk_cols)
my_own_cols = col_end - col_start
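# For instance (illustrative numbers, not prescribed by the PR): with nProcs = 4,
# P_prime = 2, so blk_rows = ceil(32 / 2) = 16 and blk_cols = ceil(35 / 2) = 18;
# group 0 owns columns 0:18 (18 cols) and group 1 owns columns 18:35 (17 cols).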

# ======================= BROADCASTING THE SLICES =======================
if rank == 0:
    A = np.arange(M * K, dtype=np.float32).reshape(M, K)
    B = np.arange(K * N, dtype=np.float32).reshape(K, N)
    for dest in range(nProcs):
        pg = dest % P_prime
        rs = pg * blk_rows
        re = min(M, rs + blk_rows)
        cs = pg * blk_cols
        ce = min(N, cs + blk_cols)
        a_block, b_block = A[rs:re, :].copy(), B[:, cs:ce].copy()
        if dest == 0:
            A_p, B_p = a_block, b_block
        else:
            comm.Send(a_block, dest=dest, tag=100 + dest)
            comm.Send(b_block, dest=dest, tag=200 + dest)
else:
    A_p = np.empty((my_own_rows, K), dtype=np.float32)
    B_p = np.empty((K, my_own_cols), dtype=np.float32)
    comm.Recv(A_p, source=0, tag=100 + rank)
    comm.Recv(B_p, source=0, tag=200 + rank)

comm.Barrier()

Aop = SUMMAMatrixMult(A_p, N)
col_lens = comm.allgather(my_own_cols)
total_cols = np.add.reduce(col_lens, 0)
x = DistributedArray(global_shape=K * total_cols,
                     local_shapes=[K * col_len for col_len in col_lens],
                     partition=Partition.SCATTER,
                     mask=[i % P_prime for i in range(comm.Get_size())],
                     dtype=np.float32)
x[:] = B_p.flatten()
y = Aop @ x

# ======================= VERIFICATION =======================
A = np.arange(M * K).reshape(M, K).astype(np.float32)
B = np.arange(K * N).reshape(K, N).astype(np.float32)
C_true = A @ B
Z_true = (A.T.dot(C_true.conj())).conj()

col_start = my_layer * blk_cols  # note: output columns are indexed by layer
col_end = min(N, col_start + blk_cols)
my_own_cols = col_end - col_start
expected_y = C_true[:, col_start:col_end].flatten()

if not np.allclose(y.local_array, expected_y, atol=1e-6, rtol=1e-14):
    print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
    print(f"{rank} local: {y.local_array}, expected: {C_true[:, col_start:col_end]}")
else:
    print(f"RANK {rank}: FORWARD VERIFICATION PASSED")

z = Aop.H @ y
expected_z = Z_true[:, col_start:col_end].flatten()
if not np.allclose(z.local_array, expected_z, atol=1e-6, rtol=1e-14):
    print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
    print(f"{rank} local: {z.local_array}, expected: {Z_true[:, col_start:col_end]}")
else:
    print(f"RANK {rank}: ADJOINT VERIFICATION PASSED")
142 changes: 142 additions & 0 deletions pylops_mpi/basicoperators/MatrixMultiply.py
@@ -0,0 +1,142 @@
import math

import numpy as np
from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray

from pylops_mpi import (
    DistributedArray,
    MPILinearOperator,
    Partition
)


class SUMMAMatrixMult(MPILinearOperator):
    """Distributed matrix-matrix multiplication over a SUMMA-style 2D process
    grid: ranks are arranged into ``P_prime`` groups (block rows of ``A``)
    and ``C`` layers (block columns of the operand).
    """

    def __init__(
        self,
        A: NDArray,
        N: int,
        base_comm: MPI.Comm = MPI.COMM_WORLD,
        dtype: DTypeLike = "float64",
    ) -> None:
        rank = base_comm.Get_rank()
        size = base_comm.Get_size()

        # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
        self._P_prime = int(math.ceil(math.sqrt(size)))
        self._C = int(math.ceil(size / self._P_prime))
        assert self._P_prime * self._C >= size

        # Compute this process's group and layer indices
        self._group_id = rank % self._P_prime
        self._layer_id = rank // self._P_prime
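        # For example (an illustrative case, not from the PR): with size=6,
        # P_prime = ceil(sqrt(6)) = 3 and C = ceil(6/3) = 2, so ranks 0-5 map
        # to (group, layer) = (0,0), (1,0), (2,0), (0,1), (1,1), (2,1).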

        # Split communicators by layer (rows) and by group (columns)
        self.base_comm = base_comm
        self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
        self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)

        self.dtype = np.dtype(dtype)
        self.A = np.array(A, dtype=self.dtype, copy=False)

        self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
        self.K = A.shape[1]
        self.N = N

        # Determine how many columns each group holds
        block_cols = int(math.ceil(self.N / self._P_prime))
        local_col_start = self._group_id * block_cols
        local_col_end = min(self.N, local_col_start + block_cols)
        local_ncols = local_col_end - local_col_start

        # Sum up the total number of input columns across all processes
        total_ncols = base_comm.allreduce(local_ncols, op=MPI.SUM)
        self.dims = (self.K, total_ncols)

        # Compute how many output columns each layer holds
        layer_col_start = self._layer_id * block_cols
        layer_col_end = min(self.N, layer_col_start + block_cols)
        layer_ncols = layer_col_end - layer_col_start
        total_layer_cols = self.base_comm.allreduce(layer_ncols, op=MPI.SUM)

        self.dimsd = (self.M, total_layer_cols)
        shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))

        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
        blk_cols = int(math.ceil(self.N / self._P_prime))
        col_start = self._group_id * blk_cols
        col_end = min(self.N, col_start + blk_cols)
        my_own_cols = max(0, col_end - col_start)
        x = x.local_array.reshape((self.dims[0], my_own_cols))
        x = x.astype(self.dtype, copy=False)
        B_block = self._layer_comm.bcast(x if self._group_id == self._layer_id else None, root=self._layer_id)
        C_local = ncp.vstack(
            self._layer_comm.allgather(
                ncp.matmul(self.A, B_block)
            )
        )
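        # After the allgather over the layer communicator, C_local stacks the
        # row blocks computed by every group in this layer, i.e. the full
        # (M, layer_ncols) slice of A @ B for this layer's column block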

        layer_col_start = self._layer_id * blk_cols
        layer_col_end = min(self.N, layer_col_start + blk_cols)
        layer_ncols = max(0, layer_col_end - layer_col_start)
        layer_col_lens = self.base_comm.allgather(layer_ncols)
        mask = [i // self._P_prime for i in range(self.size)]

        y = DistributedArray(global_shape=self.M * self.dimsd[1],
                             local_shapes=[self.M * c for c in layer_col_lens],
                             mask=mask,
                             partition=Partition.SCATTER,
                             dtype=self.dtype)
        y[:] = C_local.flatten()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")

        # Determine local column block for this layer
        blk_cols = int(math.ceil(self.N / self._P_prime))
        layer_col_start = self._layer_id * blk_cols
        layer_col_end = min(self.N, layer_col_start + blk_cols)
        layer_ncols = layer_col_end - layer_col_start
        layer_col_lens = self.base_comm.allgather(layer_ncols)
        x = x.local_array.reshape((self.M, layer_ncols))

        # Determine local row block for this process group
        blk_rows = int(math.ceil(self.M / self._P_prime))
        row_start = self._group_id * blk_rows
        row_end = min(self.M, row_start + blk_rows)

        B_tile = x[row_start:row_end, :].astype(self.dtype, copy=False)
        A_local = self.A.T.conj()

        # Pad the rows of A^H to a multiple of P_prime so it can be reshaped
        # into P_prime equally sized batches
        m, b = A_local.shape
        pad = (-m) % self._P_prime
        r = (m + pad) // self._P_prime
        A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=0)
        A_batch = A_pad.reshape(self._P_prime, r, b)

        # Perform local matmul and unpad
        Y_batch = ncp.matmul(A_batch, B_tile)
        Y_pad = Y_batch.reshape(r * self._P_prime, -1)
        y_local = Y_pad[:m, :]
        y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)
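        # Each group in the layer contributes the partial product of its own
        # row block of x; the allreduce over the layer communicator sums these
        # partial products into the full (K, layer_ncols) block of A^H @ x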

        mask = [i // self._P_prime for i in range(self.size)]
        y = DistributedArray(
            global_shape=self.K * self.dimsd[1],
            local_shapes=[self.K * c for c in layer_col_lens],
            mask=mask,
            partition=Partition.SCATTER,
            dtype=self.dtype,
        )
        y[:] = y_layer.flatten()
        return y
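A note on the adjoint's padding arithmetic (a worked instance of the code above, with illustrative numbers): for K = 5 rows of A^H and P_prime = 2, pad = (-5) % 2 = 1 and r = (5 + 1) // 2 = 3, so A^H is zero-padded to 6 rows and reshaped into 2 batches of 3 rows each; the zero rows contribute nothing to the product and are sliced away after the batched matmul.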