Skip to content

Commit ef3c283

Browse files
committed
changed tests
1 parent 18db078 commit ef3c283

File tree

2 files changed

+23
-42
lines changed

2 files changed

+23
-42
lines changed

examples/plot_matrixmult.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,6 @@
182182
# the returned 1D-arrays into 2D-arrays.
183183

184184
# Local benchmarks
185-
y_loc = A @ X
186-
xadj_loc = (A.T.dot(y_loc.conj())).conj()
187-
188185
y = y.asarray(masked=True)
189186
if N > 1:
190187
y = y.reshape(p_prime, M, blk_cols)

tests/test_matrixmult.py

Lines changed: 23 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -51,28 +51,27 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
5151
row_start_A = my_group * blk_rows_A
5252
row_end_A = min(M, row_start_A + blk_rows_A)
5353

54-
blk_cols_BC = int(math.ceil(N / p_prime))
55-
col_start_B = my_layer * blk_cols_BC
56-
col_end_B = min(N, col_start_B + blk_cols_BC)
57-
local_col_B_len = max(0, col_end_B - col_start_B)
58-
54+
blk_cols_X = int(math.ceil(N / p_prime))
55+
col_start_X = my_layer * blk_cols_X
56+
col_end_X = min(N, col_start_X + blk_cols_X)
57+
local_col_X_len = max(0, col_end_X - col_start_X)
5958

6059
A_glob_real = np.arange(M * K, dtype=base_float_dtype).reshape(M, K)
6160
A_glob_imag = np.arange(M * K, dtype=base_float_dtype).reshape(M, K) * 0.5
6261
A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)
6362

64-
B_glob_real = np.arange(K * N, dtype=base_float_dtype).reshape(K, N)
65-
B_glob_imag = np.arange(K * N, dtype=base_float_dtype).reshape(K, N) * 0.7
66-
B_glob = (B_glob_real + cmplx * B_glob_imag).astype(dtype)
63+
X_glob_real = np.arange(K * N, dtype=base_float_dtype).reshape(K, N)
64+
X_glob_imag = np.arange(K * N, dtype=base_float_dtype).reshape(K, N) * 0.7
65+
X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype)
6766

6867
A_p = A_glob[row_start_A:row_end_A,:]
69-
B_p = B_glob[:,col_start_B:col_end_B]
68+
X_p = X_glob[:,col_start_X:col_end_X]
7069

71-
# Create SUMMAMatrixMult operator
70+
# Create MPIMatrixMult operator
7271
Aop = MPIMatrixMult(A_p, N, base_comm=comm, dtype=dtype_str)
7372

7473
# Create DistributedArray for input x (representing B flattened)
75-
all_local_col_len = comm.allgather(local_col_B_len)
74+
all_local_col_len = comm.allgather(local_col_X_len)
7675
total_cols = np.sum(all_local_col_len)
7776

7877
x_dist = DistributedArray(
@@ -84,49 +83,34 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
8483
dtype=dtype
8584
)
8685

87-
x_dist.local_array[:] = B_p.ravel()
86+
x_dist.local_array[:] = X_p.ravel()
8887

8988
# Forward operation: y = A @ B (distributed)
9089
y_dist = Aop @ x_dist
91-
9290
# Adjoint operation: xadj = A.H @ y (distributed)
9391
xadj_dist = Aop.H @ y_dist
9492

95-
y_loc = A_glob @ B_glob
96-
xadj_loc = A_glob.conj().T @ y_loc
97-
98-
col_start_C_dist = my_layer * blk_cols_BC
99-
col_end_C_dist = min(N, col_start_C_dist + blk_cols_BC)
100-
my_own_cols_C_dist = max(0, col_end_C_dist - col_start_C_dist)
101-
expected_y_shape = (M * my_own_cols_C_dist,)
93+
y = y_dist.asarray(masked=True)
94+
y = y.reshape(p_prime, M, blk_cols_X)
10295

103-
assert y_dist.local_array.shape == expected_y_shape, (
104-
f"Rank {rank}: y_dist shape {y_dist.local_array.shape} != expected {expected_y_shape}"
105-
)
96+
xadj = xadj_dist.asarray(masked=True)
97+
xadj = xadj.reshape(p_prime, K, blk_cols_X)
10698

107-
if y_dist.local_array.size > 0 and y_loc is not None and y_loc.size > 0:
108-
expected_y_slice = y_loc[:, col_start_C_dist:col_end_C_dist]
99+
if rank == 0:
100+
y_loc = (A_glob @ X_glob).squeeze()
109101
assert_allclose(
110-
y_dist.local_array,
111-
expected_y_slice.ravel(),
102+
y,
103+
y_loc,
112104
rtol=np.finfo(np.dtype(dtype)).resolution,
113105
err_msg=f"Rank {rank}: Forward verification failed."
114106
)
115107

116-
# Verify adjoint operation (xadj = A.H @ y)
117-
expected_xadj_shape = (K * my_own_cols_C_dist,)
118-
assert xadj_dist.local_array.shape == expected_xadj_shape, (
119-
f"Rank {rank}: z_dist shape {xadj_dist.local_array.shape} != expected {expected_xadj_shape}"
120-
)
121-
122-
# Verify adjoint result values
123-
if xadj_dist.local_array.size > 0 and xadj_loc is not None and xadj_loc .size > 0:
124-
expected_xadj_slice = xadj_loc [:, col_start_C_dist:col_end_C_dist]
108+
xadj_loc = (A_glob.conj().T @ y_loc.conj()).conj().squeeze()
125109
assert_allclose(
126-
xadj_dist.local_array,
127-
expected_xadj_slice.ravel(),
110+
xadj,
111+
xadj_loc,
128112
rtol=np.finfo(np.dtype(dtype)).resolution,
129-
err_msg=f"Rank {rank}: Adjoint verification failed."
113+
err_msg=f"Rank {rank}: Ajoint verification failed."
130114
)
131115

132116
group_comm.Free()

0 commit comments

Comments
 (0)