
Commit a88dec3

More minor changes
1 parent 740030d commit a88dec3

File tree

- pylops_mpi/basicoperators/MatrixMult.py
- tests/test_matrixmult.py

2 files changed: +29 -33 lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -91,13 +91,15 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
+
         y = DistributedArray(
             global_shape=(self.K * self.dimsd[1]),
             local_shapes=[self.K * c for c in self._rank_col_lens],
             mask=x.mask,
             partition=Partition.SCATTER,
             dtype=self.dtype,
         )
+
         x_arr = x.local_array.reshape((self.M, self._local_ncols)).astype(self.dtype)
         X_tile = x_arr[self._row_start:self._row_end, :]
```
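For context on the hunk above: the `local_shapes` passed to `DistributedArray` must add up to `global_shape`, one flattened `K x c` block per rank, which presumes `sum(self._rank_col_lens) == self.dimsd[1]`. A minimal sketch of that bookkeeping (plain Python, no MPI; `K` and the per-rank column counts are illustrative values, not taken from the operator):

```python
# Sketch of the shape bookkeeping in the DistributedArray built by _rmatvec.
# Values are illustrative, not taken from the operator.
K = 6                      # rows of each rank's output tile (assumed)
rank_col_lens = [3, 3, 2]  # columns owned by each of 3 ranks (assumed)

local_shapes = [K * c for c in rank_col_lens]
global_shape = K * sum(rank_col_lens)

# Each rank holds a flattened K x c block; together they tile the array.
assert sum(local_shapes) == global_shape
print(local_shapes, global_shape)  # [18, 18, 12] 48
```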

tests/test_matrixmult.py

Lines changed: 27 additions & 33 deletions

```diff
@@ -35,27 +35,26 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
     cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
     base_float_dtype = np.float32 if dtype == np.complex64 else np.float64

-    P_prime = int(math.ceil(math.sqrt(size)))
-    C = int(math.ceil(size / P_prime))
-    assert P_prime * C >= size  # Ensure process grid covers all processes
+    p_prime = int(math.ceil(math.sqrt(size)))
+    C = int(math.ceil(size / p_prime))
+    assert p_prime * C == size

-    my_group = rank % P_prime
-    my_layer = rank // P_prime
+    my_group = rank % p_prime
+    my_layer = rank // p_prime

     # Create sub-communicators
     layer_comm = comm.Split(color=my_layer, key=my_group)
     group_comm = comm.Split(color=my_group, key=my_layer)

     # Calculate local matrix dimensions
-    blk_rows_A = int(math.ceil(M / P_prime))
+    blk_rows_A = int(math.ceil(M / p_prime))
     row_start_A = my_group * blk_rows_A
     row_end_A = min(M, row_start_A + blk_rows_A)
-    my_own_rows_A = max(0, row_end_A - row_start_A)

-    blk_cols_BC = int(math.ceil(N / P_prime))
+    blk_cols_BC = int(math.ceil(N / p_prime))
     col_start_B = my_layer * blk_cols_BC
     col_end_B = min(N, col_start_B + blk_cols_BC)
-    my_own_cols_B = max(0, col_end_B - col_start_B)
+    local_col_B_len = max(0, col_end_B - col_start_B)


     A_glob_real = np.arange(M * K, dtype=base_float_dtype).reshape(M, K)
```
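The grid setup changed above is easy to check in isolation: with the new `==` assertion the test requires `p_prime * C` to exactly equal the number of ranks, rather than merely cover it. A minimal sketch of the rank-to-grid mapping (plain Python, assuming `size = 4` ranks):

```python
import math

# Sketch of the test's sqrt(P) x C process grid, assuming size = 4 ranks.
size = 4
p_prime = int(math.ceil(math.sqrt(size)))  # 2
C = int(math.ceil(size / p_prime))         # 2
assert p_prime * C == size                 # the test now requires an exact grid

for rank in range(size):
    my_group = rank % p_prime   # position within a layer
    my_layer = rank // p_prime  # which layer the rank belongs to
    print(f"rank {rank}: group {my_group}, layer {my_layer}")
# rank 0: group 0, layer 0
# rank 1: group 1, layer 0
# rank 2: group 0, layer 1
# rank 3: group 1, layer 1
```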
```diff
@@ -73,33 +72,28 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
     Aop = MPIMatrixMult(A_p, N, base_comm=comm, dtype=dtype_str)

     # Create DistributedArray for input x (representing B flattened)
-    all_my_own_cols_B = comm.allgather(my_own_cols_B)
-    total_cols = np.sum(all_my_own_cols_B)
+    all_local_col_len = comm.allgather(local_col_B_len)
+    total_cols = np.sum(all_local_col_len)

     x_dist = DistributedArray(
         global_shape=(K * total_cols),
-        local_shapes=[K * cl_b for cl_b in all_my_own_cols_B],
+        local_shapes=[K * cl_b for cl_b in all_local_col_len],
         partition=Partition.SCATTER,
         base_comm=comm,
+        mask=[i // p_prime for i in range(size)],
         dtype=dtype
     )

-    if B_p.size > 0:
-        x_dist.local_array[:] = B_p.ravel()
-    else:
-        assert x_dist.local_array.size == 0, (
-            f"Rank {rank}: B_p empty but x_dist.local_array not empty "
-            f"(size {x_dist.local_array.size})"
-        )
+    x_dist.local_array[:] = B_p.ravel()

     # Forward operation: y = A @ B (distributed)
     y_dist = Aop @ x_dist

-    # Adjoint operation: z = A.H @ y (distributed y representing C)
-    z_dist = Aop.H @ y_dist
+    # Adjoint operation: xadj = A.H @ y (distributed)
+    xadj_dist = Aop.H @ y_dist

-    C_true = A_glob @ B_glob
-    Z_true = A_glob.conj().T @ C_true
+    y_loc = A_glob @ B_glob
+    xadj_loc = A_glob.conj().T @ y_loc

     col_start_C_dist = my_layer * blk_cols_BC
     col_end_C_dist = min(N, col_start_C_dist + blk_cols_BC)
```
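The `mask` added to `x_dist` tags each rank with its layer index (`i // p_prime`), so that ranks holding pieces of the same columns are grouped together for global operations on the array. A minimal sketch of the resulting mask values (plain Python, assuming `size = 4` and `p_prime = 2`):

```python
# Sketch: the mask added to x_dist assigns each rank its layer index,
# assuming size = 4 ranks and p_prime = 2 (values are illustrative).
size, p_prime = 4, 2
mask = [i // p_prime for i in range(size)]
print(mask)  # [0, 0, 1, 1] -> ranks 0,1 form layer 0; ranks 2,3 form layer 1
```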
```diff
@@ -110,27 +104,27 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
         f"Rank {rank}: y_dist shape {y_dist.local_array.shape} != expected {expected_y_shape}"
     )

-    if y_dist.local_array.size > 0 and C_true is not None and C_true.size > 0:
-        expected_y_slice = C_true[:, col_start_C_dist:col_end_C_dist]
+    if y_dist.local_array.size > 0 and y_loc is not None and y_loc.size > 0:
+        expected_y_slice = y_loc[:, col_start_C_dist:col_end_C_dist]
         assert_allclose(
             y_dist.local_array,
             expected_y_slice.ravel(),
             rtol=np.finfo(np.dtype(dtype)).resolution,
             err_msg=f"Rank {rank}: Forward verification failed."
         )

-    # Verify adjoint operation (z = A.H @ y)
-    expected_z_shape = (K * my_own_cols_C_dist,)
-    assert z_dist.local_array.shape == expected_z_shape, (
-        f"Rank {rank}: z_dist shape {z_dist.local_array.shape} != expected {expected_z_shape}"
+    # Verify adjoint operation (xadj = A.H @ y)
+    expected_xadj_shape = (K * my_own_cols_C_dist,)
+    assert xadj_dist.local_array.shape == expected_xadj_shape, (
+        f"Rank {rank}: xadj_dist shape {xadj_dist.local_array.shape} != expected {expected_xadj_shape}"
     )

     # Verify adjoint result values
-    if z_dist.local_array.size > 0 and Z_true is not None and Z_true.size > 0:
-        expected_z_slice = Z_true[:, col_start_C_dist:col_end_C_dist]
+    if xadj_dist.local_array.size > 0 and xadj_loc is not None and xadj_loc.size > 0:
+        expected_xadj_slice = xadj_loc[:, col_start_C_dist:col_end_C_dist]
         assert_allclose(
-            z_dist.local_array,
-            expected_z_slice.ravel(),
+            xadj_dist.local_array,
+            expected_xadj_slice.ravel(),
             rtol=np.finfo(np.dtype(dtype)).resolution,
             err_msg=f"Rank {rank}: Adjoint verification failed."
         )
```
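What the test verifies, in single-process terms: the forward product `y_loc = A_glob @ B_glob` and the adjoint `xadj_loc = A_glob.conj().T @ y_loc`, a pair that satisfies the standard dot test `<Y, A @ X> == <A.H @ Y, X>`. A minimal NumPy sketch of that relationship (sizes are illustrative):

```python
import numpy as np

# Single-process analogue of what the distributed test verifies:
# y = A @ B and xadj = A.conj().T @ y, plus the adjoint (dot) test.
rng = np.random.default_rng(0)
M, K, N = 5, 4, 3
A = rng.standard_normal((M, K)) + 1j * rng.standard_normal((M, K))
B = rng.standard_normal((K, N)) + 1j * rng.standard_normal((K, N))

y = A @ B              # forward, matches y_loc in the test
xadj = A.conj().T @ y  # adjoint, matches xadj_loc in the test

# Dot test: <Y, A @ X> == <A.H @ Y, X> for random X, Y
X = rng.standard_normal((K, N)) + 1j * rng.standard_normal((K, N))
Y = rng.standard_normal((M, N)) + 1j * rng.standard_normal((M, N))
lhs = np.vdot(Y, A @ X)
rhs = np.vdot(A.conj().T @ Y, X)
assert np.isclose(lhs, rhs)
```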
