Skip to content

Commit 4e39068

Browse files
committed
Fixed tests and moved checks to root
1 parent ef3c283 commit 4e39068

File tree

2 files changed

+51
-20
lines changed

2 files changed

+51
-20
lines changed

examples/plot_matrixmult.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
rank = comm.Get_rank() # rank of current process
4242
size = comm.Get_size() # number of processes
4343

44-
p_prime = int(math.ceil(math.sqrt(size)))
45-
repl_factor = int(math.ceil(size / p_prime))
44+
p_prime = math.isqrt(size)
45+
repl_factor = p_prime
4646

4747
if (p_prime * repl_factor) != size:
4848
print(f"Number of processes must be a square number, provided {size} instead...")
@@ -183,13 +183,27 @@
183183

184184
# Local benchmarks
185185
y = y.asarray(masked=True)
186-
if N > 1:
187-
y = y.reshape(p_prime, M, blk_cols)
188-
y = np.hstack([yblock for yblock in y])
186+
col_counts = [min(blk_cols, N - j * blk_cols) for j in range(p_prime)]
187+
y_blocks = []
188+
offset = 0
189+
for cnt in col_counts:
190+
block_size = M * cnt
191+
y_blocks.append(
192+
y[offset: offset + block_size].reshape(M, cnt)
193+
)
194+
offset += block_size
195+
y = np.hstack(y_blocks)
196+
189197
xadj = xadj.asarray(masked=True)
190-
if N > 1:
191-
xadj = xadj.reshape(p_prime, K, blk_cols)
192-
xadj = np.hstack([xadjblock for xadjblock in xadj])
198+
xadj_blocks = []
199+
offset = 0
200+
for cnt in col_counts:
201+
block_size = K * cnt
202+
xadj_blocks.append(
203+
xadj[offset: offset + block_size].reshape(K, cnt)
204+
)
205+
offset += block_size
206+
xadj = np.hstack(xadj_blocks)
193207

194208
if rank == 0:
195209
y_loc = (A @ X).squeeze()

tests/test_matrixmult.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
3535
cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
3636
base_float_dtype = np.float32 if dtype == np.complex64 else np.float64
3737

38-
p_prime = int(math.ceil(math.sqrt(size)))
39-
C = int(math.ceil(size / p_prime))
40-
assert p_prime * C == size
38+
p_prime = math.isqrt(size)
39+
C = p_prime
40+
assert p_prime * C == size, f"Number of processes must be a square number, provided {size} instead..."
4141

4242
my_group = rank % p_prime
4343
my_layer = rank // p_prime
@@ -90,25 +90,42 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
9090
# Adjoint operation: xadj = A.H @ y (distributed)
9191
xadj_dist = Aop.H @ y_dist
9292

93-
y = y_dist.asarray(masked=True)
94-
y = y.reshape(p_prime, M, blk_cols_X)
93+
y = y_dist.asarray(masked=True)
94+
col_counts = [min(blk_cols_X, N - j * blk_cols_X) for j in range(p_prime)]
95+
y_blocks = []
96+
offset = 0
97+
for cnt in col_counts:
98+
block_size = M * cnt
99+
y_blocks.append(
100+
y[offset: offset + block_size].reshape(M, cnt)
101+
)
102+
offset += block_size
103+
y = np.hstack(y_blocks)
95104

96105
xadj = xadj_dist.asarray(masked=True)
97-
xadj = xadj.reshape(p_prime, K, blk_cols_X)
106+
xadj_blocks = []
107+
offset = 0
108+
for cnt in col_counts:
109+
block_size = K * cnt
110+
xadj_blocks.append(
111+
xadj[offset: offset + block_size].reshape(K, cnt)
112+
)
113+
offset += block_size
114+
xadj = np.hstack(xadj_blocks)
98115

99116
if rank == 0:
100-
y_loc = (A_glob @ X_glob).squeeze()
117+
y_loc = A_glob @ X_glob
101118
assert_allclose(
102-
y,
103-
y_loc,
119+
y.squeeze(),
120+
y_loc.squeeze(),
104121
rtol=np.finfo(np.dtype(dtype)).resolution,
105122
err_msg=f"Rank {rank}: Forward verification failed."
106123
)
107124

108-
xadj_loc = (A_glob.conj().T @ y_loc.conj()).conj().squeeze()
125+
xadj_loc = A_glob.conj().T @ y_loc
109126
assert_allclose(
110-
xadj,
111-
xadj_loc,
127+
xadj.squeeze(),
128+
xadj_loc.squeeze(),
112129
rtol=np.finfo(np.dtype(dtype)).resolution,
113130
err_msg=f"Rank {rank}: Ajoint verification failed."
114131
)

0 commit comments

Comments
 (0)