Skip to content

Commit f72fce6

Browse files
committed
Addressed some comments
1 parent 6a9d382 commit f72fce6

File tree

3 files changed

+22
-15
lines changed

3 files changed

+22
-15
lines changed

examples/matrixmul.py renamed to examples/plot_matrixmult.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1-
import sys
1+
"""
2+
Distributed Matrix Multiplication
3+
=========================
4+
This example shows how to use the :py:class:`pylops_mpi.basicoperators.MatrixMultiply.SUMMAMatrixMult`.
5+
This class provides a way to distribute arrays across multiple processes in
6+
a parallel computing environment.
7+
"""
8+
29
import math
310
import numpy as np
411
from mpi4py import MPI
512

613
from pylops_mpi import DistributedArray, Partition
7-
from pylops_mpi.basicoperators.MatrixMultiply import SUMMAMatrixMult
14+
from pylops_mpi.basicoperators.MatrixMultiply import MPISUMMAMatrixMult
815

916
np.random.seed(42)
1017

@@ -15,12 +22,15 @@
1522

1623
P_prime = int(math.ceil(math.sqrt(nProcs)))
1724
C = int(math.ceil(nProcs / P_prime))
18-
assert P_prime * C >= nProcs
25+
26+
if P_prime * C < nProcs:
27+
print("No. of procs has to be a square number")
28+
exit(-1)
1929

2030
# matrix dims
21-
M = 32 # any M
22-
K = 32 # any K
23-
N = 35 # any N
31+
M = 32
32+
K = 32
33+
N = 35
2434

2535
blk_rows = int(math.ceil(M / P_prime))
2636
blk_cols = int(math.ceil(N / P_prime))
@@ -66,7 +76,7 @@
6676

6777
comm.Barrier()
6878

69-
Aop = SUMMAMatrixMult(A_p, N)
79+
Aop = MPISUMMAMatrixMult(A_p, N)
7080
col_lens = comm.allgather(my_own_cols)
7181
total_cols = np.add.reduce(col_lens, 0)
7282
x = DistributedArray(global_shape=K * total_cols,

pylops_mpi/basicoperators/MatrixMultiply.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212

1313

14-
class SUMMAMatrixMult(MPILinearOperator):
14+
class MPISUMMAMatrixMult(MPILinearOperator):
1515
def __init__(
1616
self,
1717
A: NDArray,
@@ -90,7 +90,6 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
9090
y = DistributedArray(global_shape= (self.M * self.dimsd[1]),
9191
local_shapes=[(self.M * c) for c in layer_col_lens],
9292
mask=mask,
93-
#axis=1,
9493
partition=Partition.SCATTER,
9594
dtype=self.dtype)
9695
y[:] = C_local.flatten()
@@ -134,7 +133,6 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
134133
global_shape=(self.K * self.dimsd[1]),
135134
local_shapes=[self.K * c for c in layer_col_lens],
136135
mask=mask,
137-
#axis=1
138136
partition=Partition.SCATTER,
139137
dtype=self.dtype,
140138
)

tests/test_matrixmult.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
from numpy.testing import assert_allclose
44
from mpi4py import MPI
55
import math
6-
import sys
76

87
from pylops_mpi import DistributedArray, Partition
9-
from pylops_mpi.basicoperators.MatrixMultiply import SUMMAMatrixMult
8+
from pylops_mpi.basicoperators.MatrixMultiply import MPISUMMAMatrixMult
109

1110
np.random.seed(42)
1211

@@ -109,7 +108,7 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
109108
comm.Barrier()
110109

111110
# Create SUMMAMatrixMult operator
112-
Aop = SUMMAMatrixMult(A_p, N, base_comm=comm, dtype=dtype_str)
111+
Aop = MPISUMMAMatrixMult(A_p, N, base_comm=comm, dtype=dtype_str)
113112

114113
# Create DistributedArray for input x (representing B flattened)
115114
all_my_own_cols_B = comm.allgather(my_own_cols_B)
@@ -166,7 +165,7 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
166165
y_dist.local_array,
167166
expected_y_slice.ravel(),
168167
rtol=1e-14,
169-
atol=1e-7,
168+
atol=1e-14,
170169
err_msg=f"Rank {rank}: Forward verification failed."
171170
)
172171

@@ -183,7 +182,7 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
183182
z_dist.local_array,
184183
expected_z_slice.ravel(),
185184
rtol=1e-14,
186-
atol=1e-7,
185+
atol=1e-14,
187186
err_msg=f"Rank {rank}: Adjoint verification failed."
188187
)
189188

0 commit comments

Comments
 (0)