Merged
Commits
30 commits
7fcc2cf
Added impl, test and example
astroC86 Jun 2, 2025
6a9d382
Merge branch 'PyLops:main' into astroC86-SUMMA
astroC86 Jun 2, 2025
f72fce6
Addressed some comments
astroC86 Jun 10, 2025
c607283
Example formatting
astroC86 Jun 10, 2025
de1a173
Rename MatrixMultiply file to MatrixMult
astroC86 Jun 10, 2025
82b7e34
Addressed more issues
astroC86 Jun 11, 2025
9e1a49f
Addressed comments
astroC86 Jun 13, 2025
22cde7b
Addressing changes
astroC86 Jun 13, 2025
740030d
Minor cosmetic changes
astroC86 Jun 13, 2025
a88dec3
More minor changes
astroC86 Jun 13, 2025
66f1770
Example shape dims general
astroC86 Jun 13, 2025
7ac593d
Added comments to example
astroC86 Jun 13, 2025
8a56096
I do not know why I thought I needed to batch
astroC86 Jun 13, 2025
42452a1
Initial docstring for matrix mult
astroC86 Jun 13, 2025
a110ff8
minor: cleanup of docstrings and updated example
mrava87 Jun 16, 2025
bd9ad37
minor: fix mistake in plot_matrixmult
mrava87 Jun 16, 2025
18db078
removed now useless bcast and fixed mask in test
astroC86 Jun 17, 2025
ef3c283
changed tests
astroC86 Jun 17, 2025
4e39068
Fixed tests and moved checks to root
astroC86 Jun 17, 2025
ed3b585
Fix internal check for MPIMatrixMult
astroC86 Jun 17, 2025
7b76f96
Fixed Notation
astroC86 Jun 17, 2025
3e9659e
Skipping test if number of procs is not square for now
astroC86 Jun 17, 2025
dd9b43c
Merge branch 'main' into astroC86-SUMMA
astroC86 Jun 18, 2025
a85e75a
Fixed Doc error
astroC86 Jun 26, 2025
b7e6702
Renamed layer and group to row and col respectively
astroC86 Jun 27, 2025
ae5661b
minor: small improvements to text
mrava87 Jun 29, 2025
053e52d
minor: fix flake8
mrava87 Jun 29, 2025
9aedd7c
MatrixMult works with non-square procs by creating square subcommunicator
astroC86 Jun 30, 2025
4c662d6
minor: stylistic fixes
mrava87 Jul 1, 2025
0c34b78
minor: fix flake8
mrava87 Jul 1, 2025
53 changes: 23 additions & 30 deletions examples/plot_matrixmult.py
@@ -35,20 +35,6 @@
# filled with data that is appropriate for the use-case.
np.random.seed(42)

###############################################################################
# Next we obtain the MPI parameters for each rank and check that the number
# of processes (``size``) is a square number
comm = MPI.COMM_WORLD
rank = comm.Get_rank() # rank of current process
size = comm.Get_size() # number of processes

p_prime = math.isqrt(size)
repl_factor = p_prime

if (p_prime * repl_factor) != size:
print(f"Number of processes must be a square number, provided {size} instead...")
exit(-1)

###############################################################################
# We are now ready to create the input matrices: :math:`\mathbf{A}` of size
# :math:`N \times K` and :math:`\mathbf{X}` of size :math:`K \times M`
@@ -93,12 +79,15 @@
# └────────────┴────────────┘
# </div>

my_col = rank % p_prime
my_row = rank // p_prime
base_comm = MPI.COMM_WORLD
comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {"active" if is_active else "inactive"}")
if not is_active: exit(0)
Contributor

I think in a real-life problem this would not fly, as we usually don't want to use operators in isolation but to combine them... so this way I suspect some ranks will be killed and, if they are then used by other operators, the overall run will crash.

I am not too concerned, as in practice we will hardly end up in these edge cases; for now, if someone does this for just MatrixMult things will still work, which is nice... if they do it for more complex operators that include MatrixMult, they will have to handle it...

Contributor Author
@astroC86 astroC86 Jul 2, 2025

We can get rid of the exit if you want; it shouldn't affect anything, since all operations from then on would use the new communicator that includes only the active procs in the newly defined square grid, so the other procs are free to do what they want, or just remain idle. In theory we can come up with an example where:

base_comm = MPI.COMM_WORLD
comm, ......= MPIMatrixMult.active_grid_comm(base_comm, N, M)
....
## matrix mul here would use the new comm `comm`
.....
# all procs have to reach this barrier before proceeding to perform the second op
base_comm.barrier() 
## second op here
....
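
For completeness, here is a minimal runnable sketch of this pattern (the sizes are hypothetical and the body of the active branch is left as a placeholder; it is an illustration, not code from the PR): active ranks do their matrix-multiplication work on the square sub-communicator returned by active_grid_comm, while every rank of the parent communicator, active or not, meets again at the barrier before any follow-up step that involves all processes.

from mpi4py import MPI
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

N, K, M = 8, 6, 4  # hypothetical global sizes
base_comm = MPI.COMM_WORLD
comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)

if is_active:
    # build the local blocks and apply MPIMatrixMult on the sub-communicator `comm`
    pass
# inactive ranks skip the block above and simply wait here
base_comm.barrier()
# ... second operation involving all ranks of base_comm ...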

Contributor
@mrava87 mrava87 Jul 3, 2025

Not sure you got my point here... in PyLops-MPI we do not really care about applying two operators one after the other; this is not the purpose of the library. What we want/need is to be able to chain/stack atomic operators (like MatrixMult) to create more scientifically interesting ones and solve inverse problems. So what I really want to be able to do is something like:

Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str)
...

Dop = FirstDerivative(dims=(N, col_end_X - col_start_X), axis=0, dtype=np.float32)
DBop = MPIBlockDiag(ops=[Dop, ], base_comm=comm, mask=cols_id)
Op = DBop @ Aop
y1_dist = Op @ x_dist

I added this to the test and checked consistency with the serial version, and everything seems to work (at least for the case with all ranks active)

This was an oversight on my side, I should have asked to check / checked that we could do it before merging... not too bad, we can add tests and examples now - will try tonight to open a PR with what I have so far

Contributor Author

Oh ok, I see your concern! That is weird, I will investigate

p_prime = math.isqrt(comm.Get_size())

# Create sub-communicators
row_comm = comm.Split(color=my_row, key=my_col) # all procs in same row
col_comm = comm.Split(color=my_col, key=my_row) # all procs in same col
row_comm = comm.Split(color=row_id, key=col_id) # all procs in same row
col_comm = comm.Split(color=col_id, key=row_id) # all procs in same col

################################################################################
# At this point we divide the rows and columns of :math:`\mathbf{A}` and
@@ -136,20 +125,20 @@
blk_rows = int(math.ceil(N / p_prime))
blk_cols = int(math.ceil(M / p_prime))

rs = my_col * blk_rows
rs = col_id * blk_rows
re = min(N, rs + blk_rows)
my_own_rows = re - rs
my_own_rows = max(0, re - rs)

cs = my_row * blk_cols
cs = row_id * blk_cols
ce = min(M, cs + blk_cols)
my_own_cols = ce - cs
my_own_cols = max(0, ce - cs)

A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()

################################################################################
# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator and the input matrix :math:`\mathbf{X}`
Aop = MPIMatrixMult(A_p, M, dtype="float32")
Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")

col_lens = comm.allgather(my_own_cols)
total_cols = np.sum(col_lens)
@@ -188,9 +177,11 @@
offset = 0
for cnt in col_counts:
block_size = N * cnt
y_blocks.append(
y[offset: offset + block_size].reshape(N, cnt)
)
y_block = y[offset: offset + block_size]
if len(y_block) != 0:
y_blocks.append(
y_block.reshape(N, cnt)
)
offset += block_size
y = np.hstack(y_blocks)

@@ -199,13 +190,15 @@
offset = 0
for cnt in col_counts:
block_size = K * cnt
xadj_blocks.append(
xadj[offset: offset + block_size].reshape(K, cnt)
)
xadj_blk = xadj[offset: offset + block_size]
if len(xadj_blk) != 0:
xadj_blocks.append(
xadj_blk.reshape(K, cnt)
)
offset += block_size
xadj = np.hstack(xadj_blocks)

if rank == 0:
if comm.Get_rank() == 0:
y_loc = (A @ X).squeeze()
xadj_loc = (A.T.dot(y_loc.conj())).conj().squeeze()

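To make the block partitioning in the example above concrete, here is a small standalone sketch (the sizes N, M and the grid side p_prime are hypothetical, not taken from the PR). It uses the same ceiling-division arithmetic as plot_matrixmult.py: with N = 5, M = 4 on a 2 x 2 grid, processes in grid column 0 own rows 0-2 of A, processes in grid column 1 own rows 3-4, and grid rows 0 and 1 own columns 0-1 and 2-3 of X, respectively.

import math

N, M, p_prime = 5, 4, 2                  # hypothetical sizes and grid side
blk_rows = int(math.ceil(N / p_prime))   # rows of A assigned per grid column
blk_cols = int(math.ceil(M / p_prime))   # columns of X assigned per grid row

for rank in range(p_prime * p_prime):
    row_id, col_id = divmod(rank, p_prime)
    rs, re = col_id * blk_rows, min(N, col_id * blk_rows + blk_rows)
    cs, ce = row_id * blk_cols, min(M, row_id * blk_cols + blk_cols)
    print(f"rank {rank}: A rows [{rs}:{re}), X cols [{cs}:{ce})")
# rank 0: A rows [0:3), X cols [0:2)
# rank 1: A rows [3:5), X cols [0:2)
# rank 2: A rows [0:3), X cols [2:4)
# rank 3: A rows [3:5), X cols [2:4)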
63 changes: 57 additions & 6 deletions pylops_mpi/basicoperators/MatrixMult.py
@@ -154,7 +154,7 @@ def __init__(
self._col_start = self._row_id * block_cols
self._col_end = min(self.M, self._col_start + block_cols)

self._local_ncols = self._col_end - self._col_start
self._local_ncols = max(0, self._col_end - self._col_start)
self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
total_ncols = np.sum(self._rank_col_lens)

@@ -168,11 +168,14 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
if x.partition != Partition.SCATTER:
raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")

y = DistributedArray(global_shape=(self.N * self.dimsd[1]),
local_shapes=[(self.N * c) for c in self._rank_col_lens],
mask=x.mask,
partition=Partition.SCATTER,
dtype=self.dtype)
y = DistributedArray(
global_shape=(self.N * self.dimsd[1]),
local_shapes=[(self.N * c) for c in self._rank_col_lens],
mask=x.mask,
partition=Partition.SCATTER,
dtype=self.dtype,
base_comm=self.base_comm
)

my_own_cols = self._rank_col_lens[self.rank]
x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
@@ -185,6 +188,53 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
y[:] = Y_local.flatten()
return y

@staticmethod
def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
"""
Configure a square process grid from a parent MPI communicator and select the subset of "active" processes.

Each process in base_comm is assigned to a logical 2D grid of size p_prime x p_prime,
where p_prime = floor(sqrt(total_ranks)). Only the first `active_dim x active_dim` processes
(by row-major order) are considered "active". Inactive ranks return immediately with no new communicator.

Parameters
----------
base_comm : MPI.Comm
The parent communicator (e.g., MPI.COMM_WORLD).
N : int
Number of rows of your global data domain.
M : int
Number of columns of your global data domain.

Returns
-------
tuple:
comm (MPI.Comm or None) : Sub-communicator including only active ranks.
rank (int) : Rank within the new sub-communicator (or original rank if inactive).
row (int) : Grid row index of this process in the active grid (or original rank if inactive).
col (int) : Grid column index of this process in the active grid (or original rank if inactive).
is_active (bool) : Flag indicating whether this rank is in the active sub-grid.
"""
rank = base_comm.Get_rank()
size = base_comm.Get_size()
p_prime = math.isqrt(size)
row, col = divmod(rank, p_prime)
active_dim = min(N, M, p_prime)
is_active = (row < active_dim and col < active_dim)

if not is_active:
return None, rank, row, col, False

active_ranks = [r for r in range(size)
if (r // p_prime) < active_dim and (r % p_prime) < active_dim]
new_group = base_comm.Get_group().Incl(active_ranks)
new_comm = base_comm.Create_group(new_group)

p_prime_new = math.isqrt(len(active_ranks))
new_rank = new_comm.Get_rank()
new_row, new_col = divmod(new_rank, p_prime_new)
return new_comm, new_rank, new_row, new_col, True

def _rmatvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
if x.partition != Partition.SCATTER:
@@ -196,6 +246,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
mask=x.mask,
partition=Partition.SCATTER,
dtype=self.dtype,
base_comm=self.base_comm
)

x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
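As a quick illustration of the grid-selection logic in active_grid_comm above, here is a standalone sketch that reproduces only its arithmetic (the rank count and data sizes are hypothetical, and the helper active_layout is illustrative, not part of the library): with 6 ranks and N = M = 4, p_prime = isqrt(6) = 2, so ranks 0-3 form the active 2 x 2 grid while ranks 4 and 5 are reported inactive and excluded from the sub-communicator.

import math

def active_layout(size, N, M):
    # Mirror the arithmetic of MPIMatrixMult.active_grid_comm without MPI
    p_prime = math.isqrt(size)            # side of the logical process grid
    active_dim = min(N, M, p_prime)       # active grid cannot exceed data dims
    layout = []
    for rank in range(size):
        row, col = divmod(rank, p_prime)  # row-major placement on the grid
        layout.append((rank, row, col, row < active_dim and col < active_dim))
    return layout

for rank, row, col, active in active_layout(size=6, N=4, M=4):
    print(f"rank {rank}: grid position ({row}, {col}) -> {'active' if active else 'inactive'}")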
45 changes: 19 additions & 26 deletions tests/test_matrixmult.py
@@ -12,9 +12,7 @@
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

np.random.seed(42)
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
base_comm = MPI.COMM_WORLD

# Define test cases: (N, K, M, dtype_str)
# M, K, N are matrix dimensions A(N,K), B(K,M)
@@ -32,31 +30,25 @@
@pytest.mark.mpi(min_size=1)
@pytest.mark.parametrize("M, K, N, dtype_str", test_params)
def test_SUMMAMatrixMult(N, K, M, dtype_str):
p_prime = math.isqrt(size)
C = p_prime
if p_prime * C != size:
pytest.skip("Number of processes must be a square number, "
"provided {size} instead...")

dtype = np.dtype(dtype_str)

cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
base_float_dtype = np.float32 if dtype == np.complex64 else np.float64

my_col = rank % p_prime
my_row = rank // p_prime
comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {"active" if is_active else "inactive"}")
if not is_active: return

# Create sub-communicators
row_comm = comm.Split(color=my_row, key=my_col)
col_comm = comm.Split(color=my_col, key=my_row)
size = comm.Get_size()
p_prime = math.isqrt(size)

# Calculate local matrix dimensions
blk_rows_A = int(math.ceil(N / p_prime))
row_start_A = my_col * blk_rows_A
row_start_A = col_id * blk_rows_A
row_end_A = min(N, row_start_A + blk_rows_A)

blk_cols_X = int(math.ceil(M / p_prime))
col_start_X = my_row * blk_cols_X
col_start_X = row_id * blk_cols_X
col_end_X = min(M, col_start_X + blk_cols_X)
local_col_X_len = max(0, col_end_X - col_start_X)

@@ -102,9 +94,11 @@ def test_SUMMAMatrixMult(N, K, M, dtype_str):
offset = 0
for cnt in col_counts:
block_size = N * cnt
y_blocks.append(
y[offset: offset + block_size].reshape(N, cnt)
)
y_block = y[offset: offset + block_size]
if len(y_block) != 0:
y_blocks.append(
y_block.reshape(N, cnt)
)
offset += block_size
y = np.hstack(y_blocks)

@@ -113,9 +107,11 @@ def test_SUMMAMatrixMult(N, K, M, dtype_str):
offset = 0
for cnt in col_counts:
block_size = K * cnt
xadj_blocks.append(
xadj[offset: offset + block_size].reshape(K, cnt)
)
xadj_blk = xadj[offset: offset + block_size]
if len(xadj_blk) != 0:
xadj_blocks.append(
xadj_blk.reshape(K, cnt)
)
offset += block_size
xadj = np.hstack(xadj_blocks)

Expand All @@ -134,7 +130,4 @@ def test_SUMMAMatrixMult(N, K, M, dtype_str):
xadj_loc.squeeze(),
rtol=np.finfo(np.dtype(dtype)).resolution,
err_msg=f"Rank {rank}: Ajoint verification failed."
)

col_comm.Free()
row_comm.Free()
)