|
28 | 28 |
|
29 | 29 | @pytest.mark.mpi(min_size=1) # SUMMA should also work for 1 process. |
30 | 30 | @pytest.mark.parametrize("M, K, N, dtype_str", test_params) |
31 | | -def test_SUMMAMatrixMult(M, K, N, dtype_str): |
| 31 | +def test_MPIMatrixMult(M, K, N, dtype_str): |
32 | 32 | dtype = np.dtype(dtype_str) |
33 | 33 |
|
34 | 34 | cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 |
@@ -133,56 +133,26 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str): |
133 | 133 |
|
134 | 134 | # Forward operation: y = A @ B (distributed) |
135 | 135 | y_dist = Aop @ x_dist |
| 136 | + y = y_dist.asarray(), |
136 | 137 |
|
137 | 138 | # Adjoint operation: y_adj = A.H @ y (distributed y representing C) |
138 | | - z_dist = Aop.H @ y_dist |
| 139 | + y_adj_dist = Aop.H @ y_dist |
| 140 | + y_adj = y_adj_dist.asarray() |
139 | 141 |
|
140 | 142 | if rank == 0: |
141 | | - if all(dim > 0 for dim in [M, K, N]): |
142 | | - C_true = A_glob @ B_glob |
143 | | - Z_true = A_glob.conj().T @ C_true |
144 | | - else: # Handle cases with zero dimensions |
145 | | - C_true = np.zeros((M, N), dtype=dtype) |
146 | | - Z_true = np.zeros((K if K > 0 else 0, N), dtype=dtype) if K > 0 else np.zeros((0, N), dtype=dtype) |
147 | | - else: |
148 | | - C_true = Z_true = None |
149 | | - |
150 | | - C_true = comm.bcast(C_true, root=0) |
151 | | - Z_true = comm.bcast(Z_true, root=0) |
152 | | - |
153 | | - col_start_C_dist = my_layer * blk_cols_BC |
154 | | - col_end_C_dist = min(N, col_start_C_dist + blk_cols_BC) |
155 | | - my_own_cols_C_dist = max(0, col_end_C_dist - col_start_C_dist) |
156 | | - expected_y_shape = (M * my_own_cols_C_dist,) |
157 | | - |
158 | | - assert y_dist.local_array.shape == expected_y_shape, ( |
159 | | - f"Rank {rank}: y_dist shape {y_dist.local_array.shape} != expected {expected_y_shape}" |
160 | | - ) |
161 | | - |
162 | | - if y_dist.local_array.size > 0 and C_true is not None and C_true.size > 0: |
163 | | - expected_y_slice = C_true[:, col_start_C_dist:col_end_C_dist] |
| 143 | + y_np = A_glob @ B_glob |
| 144 | + y_adj_np = A_glob.conj().T @ y_np |
164 | 145 | assert_allclose( |
165 | | - y_dist.local_array, |
166 | | - expected_y_slice.ravel(), |
| 146 | + y, |
| 147 | + y_np.ravel(), |
167 | 148 | rtol=1e-14, |
168 | | - atol=1e-14, |
169 | 149 | err_msg=f"Rank {rank}: Forward verification failed." |
170 | 150 | ) |
171 | 151 |
|
172 | | - # Verify adjoint operation (z = A.H @ y) |
173 | | - expected_z_shape = (K * my_own_cols_C_dist,) |
174 | | - assert z_dist.local_array.shape == expected_z_shape, ( |
175 | | - f"Rank {rank}: z_dist shape {z_dist.local_array.shape} != expected {expected_z_shape}" |
176 | | - ) |
177 | | - |
178 | | - # Verify adjoint result values |
179 | | - if z_dist.local_array.size > 0 and Z_true is not None and Z_true.size > 0: |
180 | | - expected_z_slice = Z_true[:, col_start_C_dist:col_end_C_dist] |
181 | 152 | assert_allclose( |
182 | | - z_dist.local_array, |
183 | | - expected_z_slice.ravel(), |
| 153 | + y_adj, |
| 154 | + y_adj_np.ravel(), |
184 | 155 | rtol=1e-14, |
185 | | - atol=1e-14, |
186 | 156 | err_msg=f"Rank {rank}: Adjoint verification failed." |
187 | 157 | ) |
188 | 158 |
|
|
0 commit comments