Commit 9e1a49f

Addressed comments

1 parent 82b7e34 commit 9e1a49f
3 files changed: +84 -85 lines changed

examples/plot_matrixmult.py (33 additions, 49 deletions)
@@ -11,7 +11,7 @@
 from mpi4py import MPI
 
 from pylops_mpi import DistributedArray, Partition
-from pylops_mpi.basicoperators.MatrixMult import MPISUMMAMatrixMult
+from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
 
 np.random.seed(42)
 
@@ -22,17 +22,17 @@
2222
P_prime = int(math.ceil(math.sqrt(n_procs)))
2323
C = int(math.ceil(n_procs / P_prime))
2424

25-
if P_prime * C < n_procs:
25+
if (P_prime * C) != n_procs:
2626
print("No. of procs has to be a square number")
2727
exit(-1)
2828

2929
# matrix dims
3030
M = 32
31-
K = 32
32-
N = 35
31+
K = 35
32+
N = 37
3333

34-
blk_rows = int(math.ceil(M / P_prime))
35-
blk_cols = int(math.ceil(N / P_prime))
34+
A = np.random.rand(M * K).astype(dtype=np.float32).reshape(M, K)
35+
B = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N)
3636

3737
my_group = rank % P_prime
3838
my_layer = rank // P_prime
@@ -41,75 +41,59 @@
 layer_comm = comm.Split(color=my_layer, key=my_group)  # all procs in same layer
 group_comm = comm.Split(color=my_group, key=my_layer)  # all procs in same group
 
-# Each rank will end up with:
-# A_p: shape (my_own_rows, K)
-# B_p: shape (K, my_own_cols)
-# where
+
+# Each rank will end up with:
+# - :math:`A_{p} \in \mathbb{R}^{\text{my\_own\_rows}\times K}`
+# - :math:`B_{p} \in \mathbb{R}^{K\times \text{my\_own\_cols}}`
+# where
+blk_rows = int(math.ceil(M / P_prime))
 row_start = my_group * blk_rows
 row_end = min(M, row_start + blk_rows)
 my_own_rows = row_end - row_start
 
-col_start = my_group * blk_cols  # note: same my_group index on cols
+blk_cols = int(math.ceil(N / P_prime))
+col_start = my_layer * blk_cols
 col_end = min(N, col_start + blk_cols)
 my_own_cols = col_end - col_start
 
-# ======================= BROADCASTING THE SLICES =======================
-if rank == 0:
-    A = np.arange(M * K, dtype=np.float32).reshape(M, K)
-    B = np.arange(K * N, dtype=np.float32).reshape(K, N)
-    for dest in range(n_procs):
-        pg = dest % P_prime
-        rs = pg * blk_rows
-        re = min(M, rs + blk_rows)
-        cs = pg * blk_cols
-        ce = min(N, cs + blk_cols)
-        a_block, b_block = A[rs:re, :], B[:, cs:ce]
-        if dest == 0:
-            A_p, B_p = a_block, b_block
-        else:
-            comm.Send(a_block, dest=dest, tag=100 + dest)
-            comm.Send(b_block, dest=dest, tag=200 + dest)
-else:
-    A_p = np.empty((my_own_rows, K), dtype=np.float32)
-    B_p = np.empty((K, my_own_cols), dtype=np.float32)
-    comm.Recv(A_p, source=0, tag=100 + rank)
-    comm.Recv(B_p, source=0, tag=200 + rank)
 
-comm.Barrier()
+rs = (rank % P_prime) * blk_rows
+re = min(M, rs + blk_rows)
 
-Aop = MPISUMMAMatrixMult(A_p, N)
+cs = (rank // P_prime) * blk_cols
+ce = min(N, cs + blk_cols)
+A_p, B_p = A[rs:re, :].copy(), B[:, cs:ce].copy()
+
+Aop = MPIMatrixMult(A_p, N, dtype="float32")
 col_lens = comm.allgather(my_own_cols)
 total_cols = np.sum(col_lens)
 x = DistributedArray(global_shape=K * total_cols,
                      local_shapes=[K * col_len for col_len in col_lens],
                      partition=Partition.SCATTER,
                      mask=[i % P_prime for i in range(comm.Get_size())],
-                     dtype=np.float32)
+                     base_comm=comm,
+                     dtype="float32")
 x[:] = B_p.flatten()
 y = Aop @ x
 
 # ======================= VERIFICATION =======================
-A = np.arange(M * K).reshape(M, K).astype(np.float32)
-B = np.arange(K * N).reshape(K, N).astype(np.float32)
-C_true = A @ B
-Z_true = (A.T.dot(C_true.conj())).conj()
+y_loc = A @ B
+xadj_loc = (A.T.dot(y_loc.conj())).conj()
 
-col_start = my_layer * blk_cols  # note: same my_group index on cols
-col_end = min(N, col_start + blk_cols)
-my_own_cols = col_end - col_start
-expected_y = C_true[:, col_start:col_end].flatten()
 
-xadj = Aop.H @ y
+expected_y_loc = y_loc[:, col_start:col_end].flatten().astype(np.float32)
+expected_xadj_loc = xadj_loc[:, col_start:col_end].flatten().astype(np.float32)
 
-if not np.allclose(y.local_array, expected_y, atol=1e-6, rtol=1e-14):
+xadj = Aop.H @ y
+if not np.allclose(y.local_array, expected_y_loc, rtol=1e-6):
     print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
-    print(f'{rank} local: {y.local_array}, expected: {C_true[:, col_start:col_end]}')
+    print(f'{rank} local: {y.local_array}, expected: {y_loc[:, col_start:col_end]}')
 else:
     print(f"RANK {rank}: FORWARD VERIFICATION PASSED")
 
-expected_z = Z_true[:, col_start:col_end].flatten()
-if not np.allclose(xadj.local_array, expected_z, atol=1e-6, rtol=1e-14):
+if not np.allclose(xadj.local_array, expected_xadj_loc, rtol=1e-6):
     print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
-    print(f'{rank} local: {xadj.local_array}, expected: {Z_true[:, col_start:col_end]}')
+    print(f'{rank} local: {xadj.local_array}, expected: {xadj_loc[:, col_start:col_end]}')
 else:
     print(f"RANK {rank}: ADJOINT VERIFICATION PASSED")
+
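
Note on the change above: the example now builds the global A and B on every rank and slices them locally, instead of scattering blocks from rank 0. Rank r sits in group r % P_prime and layer r // P_prime, and owns the group-th row block of A and the layer-th column block of B. A minimal MPI-free sketch of that mapping (sizes copied from the example; the loop over ranks is illustrative only):

# MPI-free sketch of the block ownership used in the example above:
# rank r owns row block (r % P_prime) of A and column block (r // P_prime) of B.
import math

M, K, N, n_procs = 32, 35, 37, 4
P_prime = int(math.ceil(math.sqrt(n_procs)))
blk_rows = int(math.ceil(M / P_prime))
blk_cols = int(math.ceil(N / P_prime))

for rank in range(n_procs):
    rs = (rank % P_prime) * blk_rows   # group index selects the row block of A
    re = min(M, rs + blk_rows)
    cs = (rank // P_prime) * blk_cols  # layer index selects the column block of B
    ce = min(N, cs + blk_cols)
    print(f"rank {rank}: A[{rs}:{re}, :], B[:, {cs}:{ce}]")

Like any mpi4py script, the example is meant to be launched under an MPI runner with a square process count, e.g. mpiexec -n 4 python examples/plot_matrixmult.py.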

pylops_mpi/basicoperators/MatrixMult.py (10 additions, 11 deletions)
@@ -11,7 +11,7 @@
 )
 
 
-class MPISUMMAMatrixMult(MPILinearOperator):
+class MPIMatrixMult(MPILinearOperator):
     def __init__(
         self,
         A: NDArray,
@@ -36,8 +36,7 @@ def __init__(
         self.base_comm = base_comm
         self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
         self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
-
-        self.A = np.array(A, dtype=self.dtype, copy=False)
+        self.A = A.astype(np.dtype(dtype))
 
         self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
         self.K = A.shape[1]
@@ -61,19 +60,19 @@ def __init__(
 
         self.dimsd = (self.M, total_layer_cols)
         shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
-
         super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
 
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
         blk_cols = int(math.ceil(self.N / self._P_prime))
-        col_start = self._group_id * blk_cols
+        col_start = self._layer_id * blk_cols
         col_end = min(self.N, col_start + blk_cols)
         my_own_cols = max(0, col_end - col_start)
         x = x.local_array.reshape((self.dims[0], my_own_cols))
-        x = x.astype(self.dtype, copy=False)
+        x = x.astype(self.dtype)
+
         B_block = self._layer_comm.bcast(x if self._group_id == self._layer_id else None, root=self._layer_id)
         C_local = ncp.vstack(
             self._layer_comm.allgather(
@@ -106,24 +105,24 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         layer_col_end = min(self.N, layer_col_start + blk_cols)
         layer_ncols = layer_col_end - layer_col_start
         layer_col_lens = self.base_comm.allgather(layer_ncols)
-        x = x.local_array.reshape((self.M, layer_ncols))
+        x = x.local_array.reshape((self.M, layer_ncols)).astype(self.dtype)
 
         # Determine local row block for this process group
         blk_rows = int(math.ceil(self.M / self._P_prime))
         row_start = self._group_id * blk_rows
         row_end = min(self.M, row_start + blk_rows)
 
-        B_tile = x[row_start:row_end, :].astype(self.dtype, copy=False)
-        A_local = self.A.T.conj()
+        B_tile = x[row_start:row_end, :].astype(self.dtype)
+        A_local = self.A.T.conj().astype(self.dtype)
 
         m, b = A_local.shape
         pad = (-m) % self._P_prime
         r = (m + pad) // self._P_prime
-        A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=0)
+        A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=self.dtype.type(0.0))
         A_batch = A_pad.reshape(self._P_prime, r, b)
 
         # Perform local matmul and unpad
-        Y_batch = ncp.matmul(A_batch, B_tile)
+        Y_batch = ncp.matmul(A_batch, B_tile).astype(self.dtype)
         Y_pad = Y_batch.reshape(r * self._P_prime, -1)
         y_local = Y_pad[:m, :]
         y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)
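
The pad/reshape/matmul/unpad sequence touched above is the part of _rmatvec most worth sanity-checking in isolation. A standalone NumPy sketch (shapes invented for illustration) showing that padding the rows of the local A^H tile to a multiple of P_prime, batching, multiplying, and trimming reproduces the plain matrix product:

# Standalone check of the pad/batch/unpad pattern from _rmatvec (illustrative shapes):
import numpy as np

P_prime = 2
A_local = np.random.rand(35, 16).astype(np.float32)  # stand-in for the A^H tile
B_tile = np.random.rand(16, 19).astype(np.float32)

m, b = A_local.shape
pad = (-m) % P_prime                    # rows needed to reach a multiple of P_prime
r = (m + pad) // P_prime                # rows per batch entry
A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode="constant", constant_values=0)
A_batch = A_pad.reshape(P_prime, r, b)  # P_prime stacked row blocks

Y_batch = np.matmul(A_batch, B_tile)    # one matmul per row block (B_tile broadcasts)
Y_pad = Y_batch.reshape(r * P_prime, -1)
y_local = Y_pad[:m, :]                  # drop the padded rows

assert np.allclose(y_local, A_local @ B_tile, rtol=1e-5)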

tests/test_matrixmult.py (41 additions, 25 deletions)
@@ -3,9 +3,10 @@
 from numpy.testing import assert_allclose
 from mpi4py import MPI
 import math
+import sys
 
 from pylops_mpi import DistributedArray, Partition
-from pylops_mpi.basicoperators.MatrixMult import MPISUMMAMatrixMult
+from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
 
 np.random.seed(42)
 
@@ -28,7 +29,7 @@
 
 @pytest.mark.mpi(min_size=1)  # SUMMA should also work for 1 process.
 @pytest.mark.parametrize("M, K, N, dtype_str", test_params)
-def test_MPIMatrixMult(M, K, N, dtype_str):
+def test_SUMMAMatrixMult(M, K, N, dtype_str):
     dtype = np.dtype(dtype_str)
 
     cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
@@ -60,18 +61,15 @@ def test_MPIMatrixMult(M, K, N, dtype_str):
     A_p = np.empty((my_own_rows_A, K), dtype=dtype)
     B_p = np.empty((K, my_own_cols_B), dtype=dtype)
 
-    # Generate and distribute test matrices
-    A_glob, B_glob = None, None
-    if rank == 0:
-        # Create global matrices with complex components if needed
-        A_glob_real = np.arange(M * K, dtype=base_float_dtype).reshape(M, K)
-        A_glob_imag = np.arange(M * K, dtype=base_float_dtype).reshape(M, K) * 0.5
-        A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)
+    A_glob_real = np.arange(M * K, dtype=base_float_dtype).reshape(M, K)
+    A_glob_imag = np.arange(M * K, dtype=base_float_dtype).reshape(M, K) * 0.5
+    A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)
 
-        B_glob_real = np.arange(K * N, dtype=base_float_dtype).reshape(K, N)
-        B_glob_imag = np.arange(K * N, dtype=base_float_dtype).reshape(K, N) * 0.7
-        B_glob = (B_glob_real + cmplx * B_glob_imag).astype(dtype)
+    B_glob_real = np.arange(K * N, dtype=base_float_dtype).reshape(K, N)
+    B_glob_imag = np.arange(K * N, dtype=base_float_dtype).reshape(K, N) * 0.7
+    B_glob = (B_glob_real + cmplx * B_glob_imag).astype(dtype)
 
+    if rank == 0:
         # Distribute matrix blocks to all ranks
         for dest_rank in range(size):
             dest_my_group = dest_rank % P_prime
@@ -108,7 +106,7 @@
     comm.Barrier()
 
     # Create SUMMAMatrixMult operator
-    Aop = MPISUMMAMatrixMult(A_p, N, base_comm=comm, dtype=dtype_str)
+    Aop = MPIMatrixMult(A_p, N, base_comm=comm, dtype=dtype_str)
 
     # Create DistributedArray for input x (representing B flattened)
     all_my_own_cols_B = comm.allgather(my_own_cols_B)
@@ -133,26 +131,44 @@
 
     # Forward operation: y = A @ B (distributed)
     y_dist = Aop @ x_dist
-    y = y_dist.asarray(),
 
     # Adjoint operation: z = A.H @ y (distributed y representing C)
-    y_adj_dist = Aop.H @ y_dist
-    y_adj = y_adj_dist.asarray()
+    z_dist = Aop.H @ y_dist
 
-    if rank == 0:
-        y_np = A_glob @ B_glob
-        y_adj_np = A_glob.conj().T @ y_np
+    C_true = A_glob @ B_glob
+    Z_true = A_glob.conj().T @ C_true
+
+    col_start_C_dist = my_layer * blk_cols_BC
+    col_end_C_dist = min(N, col_start_C_dist + blk_cols_BC)
+    my_own_cols_C_dist = max(0, col_end_C_dist - col_start_C_dist)
+    expected_y_shape = (M * my_own_cols_C_dist,)
+
+    assert y_dist.local_array.shape == expected_y_shape, (
+        f"Rank {rank}: y_dist shape {y_dist.local_array.shape} != expected {expected_y_shape}"
+    )
+
+    if y_dist.local_array.size > 0 and C_true is not None and C_true.size > 0:
+        expected_y_slice = C_true[:, col_start_C_dist:col_end_C_dist]
         assert_allclose(
-            y,
-            y_np.ravel(),
-            rtol=1e-14,
+            y_dist.local_array,
+            expected_y_slice.ravel(),
+            rtol=np.finfo(np.dtype(dtype)).resolution,
             err_msg=f"Rank {rank}: Forward verification failed."
         )
 
+    # Verify adjoint operation (z = A.H @ y)
+    expected_z_shape = (K * my_own_cols_C_dist,)
+    assert z_dist.local_array.shape == expected_z_shape, (
+        f"Rank {rank}: z_dist shape {z_dist.local_array.shape} != expected {expected_z_shape}"
+    )
+
+    # Verify adjoint result values
+    if z_dist.local_array.size > 0 and Z_true is not None and Z_true.size > 0:
+        expected_z_slice = Z_true[:, col_start_C_dist:col_end_C_dist]
         assert_allclose(
-            y_adj,
-            y_adj_np.ravel(),
-            rtol=1e-14,
+            z_dist.local_array,
+            expected_z_slice.ravel(),
+            rtol=np.finfo(np.dtype(dtype)).resolution,
             err_msg=f"Rank {rank}: Adjoint verification failed."
         )
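
One note on the tolerance change in the asserts above: a fixed rtol=1e-14 is unattainable in single precision, so the test now scales the tolerance with the dtype under test via np.finfo(...).resolution. A quick check of the values this yields:

import numpy as np

for dt in ("float32", "float64", "complex64", "complex128"):
    print(dt, np.finfo(np.dtype(dt)).resolution)
# float32 1e-06, float64 1e-15, complex64 1e-06, complex128 1e-15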
