Skip to content

Commit 82b7e34

Browse files
committed
Addressed more issues
1 parent de1a173 commit 82b7e34

File tree

3 files changed

+18
-47
lines changed

3 files changed

+18
-47
lines changed

examples/plot_matrixmult.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
re = min(M, rs + blk_rows)
6464
cs = pg * blk_cols;
6565
ce = min(N, cs + blk_cols)
66-
a_block, b_block = A[rs:re, :].copy(), B[:, cs:ce].copy()
66+
a_block, b_block = A[rs:re, :], B[:, cs:ce]
6767
if dest == 0:
6868
A_p, B_p = a_block, b_block
6969
else:
@@ -79,7 +79,7 @@
7979

8080
Aop = MPISUMMAMatrixMult(A_p, N)
8181
col_lens = comm.allgather(my_own_cols)
82-
total_cols = np.add.reduce(col_lens, 0)
82+
total_cols = np.sum(col_lens)
8383
x = DistributedArray(global_shape=K * total_cols,
8484
local_shapes=[K * col_len for col_len in col_lens],
8585
partition=Partition.SCATTER,
@@ -99,16 +99,17 @@
9999
my_own_cols = col_end - col_start
100100
expected_y = C_true[:, col_start:col_end].flatten()
101101

102+
xadj = Aop.H @ y
103+
102104
if not np.allclose(y.local_array, expected_y, atol=1e-6, rtol=1e-14):
103105
print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
104106
print(f'{rank} local: {y.local_array}, expected: {C_true[:, col_start:col_end]}')
105107
else:
106108
print(f"RANK {rank}: FORWARD VERIFICATION PASSED")
107109

108-
z = Aop.H @ y
109110
expected_z = Z_true[:, col_start:col_end].flatten()
110-
if not np.allclose(z.local_array, expected_z, atol=1e-6, rtol=1e-14):
111+
if not np.allclose(xadj.local_array, expected_z, atol=1e-6, rtol=1e-14):
111112
print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
112-
print(f'{rank} local: {z.local_array}, expected: {Z_true[:, col_start:col_end]}')
113+
print(f'{rank} local: {xadj.local_array}, expected: {Z_true[:, col_start:col_end]}')
113114
else:
114115
print(f"RANK {rank}: ADJOINT VERIFICATION PASSED")

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ def __init__(
2525
# Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
2626
self._P_prime = int(math.ceil(math.sqrt(size)))
2727
self._C = int(math.ceil(size / self._P_prime))
28-
assert self._P_prime * self._C >= size
28+
if self._P_prime * self._C < size:
29+
raise Exception("Number of Procs must be a square number")
2930

3031
# Compute this process's group and layer indices
3132
self._group_id = rank % self._P_prime
@@ -36,7 +37,6 @@ def __init__(
3637
self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
3738
self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
3839

39-
self.dtype = np.dtype(dtype)
4040
self.A = np.array(A, dtype=self.dtype, copy=False)
4141

4242
self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)

tests/test_matrixmult.py

Lines changed: 10 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
@pytest.mark.mpi(min_size=1) # SUMMA should also work for 1 process.
3030
@pytest.mark.parametrize("M, K, N, dtype_str", test_params)
31-
def test_SUMMAMatrixMult(M, K, N, dtype_str):
31+
def test_MPIMatrixMult(M, K, N, dtype_str):
3232
dtype = np.dtype(dtype_str)
3333

3434
cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
@@ -133,56 +133,26 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
133133

134134
# Forward operation: y = A @ B (distributed)
135135
y_dist = Aop @ x_dist
136+
y = y_dist.asarray(),
136137

137138
# Adjoint operation: z = A.H @ y (distributed y representing C)
138-
z_dist = Aop.H @ y_dist
139+
y_adj_dist = Aop.H @ y_dist
140+
y_adj = y_adj_dist.asarray()
139141

140142
if rank == 0:
141-
if all(dim > 0 for dim in [M, K, N]):
142-
C_true = A_glob @ B_glob
143-
Z_true = A_glob.conj().T @ C_true
144-
else: # Handle cases with zero dimensions
145-
C_true = np.zeros((M, N), dtype=dtype)
146-
Z_true = np.zeros((K if K > 0 else 0, N), dtype=dtype) if K > 0 else np.zeros((0, N), dtype=dtype)
147-
else:
148-
C_true = Z_true = None
149-
150-
C_true = comm.bcast(C_true, root=0)
151-
Z_true = comm.bcast(Z_true, root=0)
152-
153-
col_start_C_dist = my_layer * blk_cols_BC
154-
col_end_C_dist = min(N, col_start_C_dist + blk_cols_BC)
155-
my_own_cols_C_dist = max(0, col_end_C_dist - col_start_C_dist)
156-
expected_y_shape = (M * my_own_cols_C_dist,)
157-
158-
assert y_dist.local_array.shape == expected_y_shape, (
159-
f"Rank {rank}: y_dist shape {y_dist.local_array.shape} != expected {expected_y_shape}"
160-
)
161-
162-
if y_dist.local_array.size > 0 and C_true is not None and C_true.size > 0:
163-
expected_y_slice = C_true[:, col_start_C_dist:col_end_C_dist]
143+
y_np = A_glob @ B_glob
144+
y_adj_np = A_glob.conj().T @ y_np
164145
assert_allclose(
165-
y_dist.local_array,
166-
expected_y_slice.ravel(),
146+
y,
147+
y_np.ravel(),
167148
rtol=1e-14,
168-
atol=1e-14,
169149
err_msg=f"Rank {rank}: Forward verification failed."
170150
)
171151

172-
# Verify adjoint operation (z = A.H @ y)
173-
expected_z_shape = (K * my_own_cols_C_dist,)
174-
assert z_dist.local_array.shape == expected_z_shape, (
175-
f"Rank {rank}: z_dist shape {z_dist.local_array.shape} != expected {expected_z_shape}"
176-
)
177-
178-
# Verify adjoint result values
179-
if z_dist.local_array.size > 0 and Z_true is not None and Z_true.size > 0:
180-
expected_z_slice = Z_true[:, col_start_C_dist:col_end_C_dist]
181152
assert_allclose(
182-
z_dist.local_array,
183-
expected_z_slice.ravel(),
153+
y_adj,
154+
y_adj_np.ravel(),
184155
rtol=1e-14,
185-
atol=1e-14,
186156
err_msg=f"Rank {rank}: Adjoint verification failed."
187157
)
188158

0 commit comments

Comments
 (0)