
Commit 19e873a

Merge pull request #155 from mrava87/test-matrixmultchain
test/doc: added tests and example of chaining MPIMatrixMult
2 parents 2c8bf53 + fc16f5d commit 19e873a

File tree

examples/plot_matrixmult.py
tests/test_matrixmult.py

2 files changed: +91 -95 lines changed


examples/plot_matrixmult.py

Lines changed: 30 additions & 63 deletions
@@ -24,8 +24,10 @@
 import numpy as np
 from mpi4py import MPI
 
-from pylops_mpi import DistributedArray, Partition
-from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
+import pylops
+
+import pylops_mpi
+from pylops_mpi import Partition
 
 plt.close("all")
 
@@ -86,7 +88,8 @@
 # than the row or columm ranks.
 
 base_comm = MPI.COMM_WORLD
-comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
+comm, rank, row_id, col_id, is_active = \
+    pylops_mpi.MPIMatrixMult.active_grid_comm(base_comm, N, M)
 print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
 if not is_active: exit(0)
 
@@ -144,23 +147,24 @@
 ################################################################################
 # We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
 # operator and the input matrix :math:`\mathbf{X}`
-Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")
+Aop = pylops_mpi.MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")
 
 col_lens = comm.allgather(my_own_cols)
 total_cols = np.sum(col_lens)
-x = DistributedArray(global_shape=K * total_cols,
-                     local_shapes=[K * col_len for col_len in col_lens],
-                     partition=Partition.SCATTER,
-                     mask=[i % p_prime for i in range(comm.Get_size())],
-                     base_comm=comm,
-                     dtype="float32")
+x = pylops_mpi.DistributedArray(
+    global_shape=K * total_cols,
+    local_shapes=[K * col_len for col_len in col_lens],
+    partition=Partition.SCATTER,
+    mask=[i % p_prime for i in range(comm.Get_size())],
+    base_comm=comm,
+    dtype="float32")
 x[:] = X_p.flatten()
 
 ################################################################################
-# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which effectively
-# implements a distributed matrix-matrix multiplication :math:`Y = \mathbf{AX}`)
-# Note :math:`\mathbf{Y}` is distributed in the same way as the input
-# :math:`\mathbf{X}`.
+# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
+# effectively implements a distributed matrix-matrix multiplication
+# :math:`Y = \mathbf{AX}`). Note :math:`\mathbf{Y}` is distributed in the same
+# way as the input :math:`\mathbf{X}`.
 y = Aop @ x
 
 ###############################################################################
@@ -172,52 +176,15 @@
 xadj = Aop.H @ y
 
 ###############################################################################
-# To conclude we verify our result against the equivalent serial version of
-# the operation by gathering the resulting matrices in rank0 and reorganizing
-# the returned 1D-arrays into 2D-arrays.
-
-# Local benchmarks
-y = y.asarray(masked=True)
-col_counts = [min(blk_cols, M - j * blk_cols) for j in range(p_prime)]
-y_blocks = []
-offset = 0
-for cnt in col_counts:
-    block_size = N * cnt
-    y_block = y[offset: offset + block_size]
-    if len(y_block) != 0:
-        y_blocks.append(
-            y_block.reshape(N, cnt)
-        )
-    offset += block_size
-y = np.hstack(y_blocks)
-
-xadj = xadj.asarray(masked=True)
-xadj_blocks = []
-offset = 0
-for cnt in col_counts:
-    block_size = K * cnt
-    xadj_blk = xadj[offset: offset + block_size]
-    if len(xadj_blk) != 0:
-        xadj_blocks.append(
-            xadj_blk.reshape(K, cnt)
-        )
-    offset += block_size
-xadj = np.hstack(xadj_blocks)
-
-if rank == 0:
-    y_loc = (A @ X).squeeze()
-    xadj_loc = (A.T.dot(y_loc.conj())).conj().squeeze()
-
-    if not np.allclose(y, y_loc, rtol=1e-6):
-        print("FORWARD VERIFICATION FAILED")
-        print(f'distributed: {y}')
-        print(f'expected: {y_loc}')
-    else:
-        print("FORWARD VERIFICATION PASSED")
-
-    if not np.allclose(xadj, xadj_loc, rtol=1e-6):
-        print("ADJOINT VERIFICATION FAILED")
-        print(f'distributed: {xadj}')
-        print(f'expected: {xadj_loc}')
-    else:
-        print("ADJOINT VERIFICATION PASSED")
+# Finally, we show the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+# operator can be combined with any other PyLops-MPI operator. We are going to
+# apply here a first derivative along the first axis to the output of the matrix
+# multiplication. The only gotcha here is that one needs to be aware of the
+# ad-hoc distribution of the arrays that are fed to this operator and make
+# sure it is matched in the other operators involved in the chain.
+Dop = pylops.FirstDerivative(dims=(N, my_own_cols), axis=0,
+                             dtype=np.float32)
+DBop = pylops_mpi.MPIBlockDiag(ops=[Dop, ])
+Op = DBop @ Aop
+
+y1 = Op @ x
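For quick reference, the chaining step introduced by this example boils down to the sketch below. This is a condensed recap under stated assumptions, not a standalone script: it assumes Aop, x, comm, N and my_own_cols have already been created as in the earlier sections of examples/plot_matrixmult.py.

# Sketch of the chaining step (assumes Aop, x, comm, N and my_own_cols exist
# as defined earlier in the example).
import numpy as np
import pylops
import pylops_mpi

# local first derivative acting on this rank's (N, my_own_cols) block of Y
Dop = pylops.FirstDerivative(dims=(N, my_own_cols), axis=0, dtype=np.float32)
# block-diagonal wrapper so every rank differentiates its own column block
DBop = pylops_mpi.MPIBlockDiag(ops=[Dop, ])
# chained operator: distributed matrix multiplication followed by derivative
Op = DBop @ Aop

y1 = Op @ x        # forward: derivative of A @ X, distributed by column blocks
xadj1 = Op.H @ y1  # adjoint: A^H applied to D^H y1, distributed like x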

tests/test_matrixmult.py

Lines changed: 61 additions & 32 deletions
@@ -8,8 +8,9 @@
 from mpi4py import MPI
 import pytest
 
+from pylops.basicoperators import FirstDerivative, Identity
 from pylops_mpi import DistributedArray, Partition
-from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
+from pylops_mpi.basicoperators import MPIMatrixMult, MPIBlockDiag
 
 np.random.seed(42)
 base_comm = MPI.COMM_WORLD
@@ -19,28 +20,48 @@
 # M, K, N are matrix dimensions A(N,K), B(K,M)
 # P_prime will be ceil(sqrt(size)).
 test_params = [
-    pytest.param(37, 37, 37, "float32", id="f32_37_37_37"),
+    pytest.param(37, 37, 37, "float64", id="f32_37_37_37"),
     pytest.param(50, 30, 40, "float64", id="f64_50_30_40"),
     pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"),
     pytest.param(3, 4, 5, "float32", id="f32_3_4_5"),
     pytest.param(1, 2, 1, "float64", id="f64_1_2_1",),
     pytest.param(2, 1, 3, "float32", id="f32_2_1_3",),
 ]
 
+def _reorganize_local_matrix(x_dist, N, M, blk_cols, p_prime):
+    """Re-organize distributed array in local matrix
+    """
+    x = x_dist.asarray(masked=True)
+    col_counts = [min(blk_cols, M - j * blk_cols) for j in range(p_prime)]
+    x_blocks = []
+    offset = 0
+    for cnt in col_counts:
+        block_size = N * cnt
+        x_block = x[offset: offset + block_size]
+        if len(x_block) != 0:
+            x_blocks.append(
+                x_block.reshape(N, cnt)
+            )
+        offset += block_size
+    x = np.hstack(x_blocks)
+    return x
+
 
 @pytest.mark.mpi(min_size=1)
-@pytest.mark.parametrize("M, K, N, dtype_str", test_params)
+@pytest.mark.parametrize("N, K, M, dtype_str", test_params)
 def test_MPIMatrixMult(N, K, M, dtype_str):
     dtype = np.dtype(dtype_str)
 
     cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
     base_float_dtype = np.float32 if dtype == np.complex64 else np.float64
 
-    comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
+    comm, rank, row_id, col_id, is_active = \
+        MPIMatrixMult.active_grid_comm(base_comm, N, M)
     if not is_active: return
 
     size = comm.Get_size()
     p_prime = math.isqrt(size)
+    cols_id = comm.allgather(col_id)
 
     # Calculate local matrix dimensions
     blk_rows_A = int(math.ceil(N / p_prime))
@@ -52,6 +73,7 @@ def test_MPIMatrixMult(N, K, M, dtype_str):
     col_end_X = min(M, col_start_X + blk_cols_X)
     local_col_X_len = max(0, col_end_X - col_start_X)
 
+    # Fill local matrices
     A_glob_real = np.arange(N * K, dtype=base_float_dtype).reshape(N, K)
     A_glob_imag = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5
     A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)
@@ -88,32 +110,8 @@ def test_MPIMatrixMult(N, K, M, dtype_str):
     xadj_dist = Aop.H @ y_dist
 
     # Re-organize in local matrix
-    y = y_dist.asarray(masked=True)
-    col_counts = [min(blk_cols_X, M - j * blk_cols_X) for j in range(p_prime)]
-    y_blocks = []
-    offset = 0
-    for cnt in col_counts:
-        block_size = N * cnt
-        y_block = y[offset: offset + block_size]
-        if len(y_block) != 0:
-            y_blocks.append(
-                y_block.reshape(N, cnt)
-            )
-        offset += block_size
-    y = np.hstack(y_blocks)
-
-    xadj = xadj_dist.asarray(masked=True)
-    xadj_blocks = []
-    offset = 0
-    for cnt in col_counts:
-        block_size = K * cnt
-        xadj_blk = xadj[offset: offset + block_size]
-        if len(xadj_blk) != 0:
-            xadj_blocks.append(
-                xadj_blk.reshape(K, cnt)
-            )
-        offset += block_size
-    xadj = np.hstack(xadj_blocks)
+    y = _reorganize_local_matrix(y_dist, N, M, blk_cols_X, p_prime)
+    xadj = _reorganize_local_matrix(xadj_dist, K, M, blk_cols_X, p_prime)
 
     if rank == 0:
         y_loc = A_glob @ X_glob
@@ -129,5 +127,36 @@ def test_MPIMatrixMult(N, K, M, dtype_str):
             xadj.squeeze(),
             xadj_loc.squeeze(),
             rtol=np.finfo(np.dtype(dtype)).resolution,
-            err_msg=f"Rank {rank}: Ajoint verification failed."
-        )
+            err_msg=f"Rank {rank}: Adjoint verification failed."
+        )
+
+    # Chain with another operator
+    Dop = FirstDerivative(dims=(N, col_end_X - col_start_X),
+                          axis=0, dtype=dtype)
+    DBop = MPIBlockDiag(ops=[Dop, ], base_comm=comm, mask=cols_id)
+    Op = DBop @ Aop
+
+    y1_dist = Op @ x_dist
+    xadj1_dist = Op.H @ y1_dist
+
+    # Re-organize in local matrix
+    y1 = _reorganize_local_matrix(y1_dist, N, M, blk_cols_X, p_prime)
+    xadj1 = _reorganize_local_matrix(xadj1_dist, K, M, blk_cols_X, p_prime)
+
+    if rank == 0:
+        Dop_glob = FirstDerivative(dims=(N, M), axis=0, dtype=dtype)
+        y1_loc = (Dop_glob @ (A_glob @ X_glob).ravel()).reshape(N, M)
+        assert_allclose(
+            y1.squeeze(),
+            y1_loc.squeeze(),
+            rtol=np.finfo(np.dtype(dtype)).resolution,
+            err_msg=f"Rank {rank}: Forward verification failed."
+        )
+
+        xadj1_loc = A_glob.conj().T @ (Dop_glob.H @ y1_loc.ravel()).reshape(N, M)
+        assert_allclose(
+            xadj1.squeeze(),
+            xadj1_loc.squeeze(),
+            rtol=np.finfo(np.dtype(dtype)).resolution,
+            err_msg=f"Rank {rank}: Adjoint verification failed."
+        )
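On rank 0, the reference for the new chained check is a purely serial computation: the forward of the chain is the first derivative (along the first axis) of A X, and the adjoint applies A^H to D^H Y1. Below is a minimal standalone sketch of that reference computation, with arbitrary small N, K, M chosen only for illustration; it is not part of the test suite.

# Serial sketch of the rank-0 reference used in the chained check
# (hypothetical standalone snippet; N, K, M are arbitrary small values).
import numpy as np
from pylops.basicoperators import FirstDerivative

N, K, M = 6, 5, 4
A = np.arange(N * K, dtype=np.float64).reshape(N, K)
X = np.arange(K * M, dtype=np.float64).reshape(K, M)

Dop_glob = FirstDerivative(dims=(N, M), axis=0, dtype=np.float64)

# forward of the chain: first derivative of the product A @ X
y1_loc = (Dop_glob @ (A @ X).ravel()).reshape(N, M)
# adjoint of the chain: A^H applied to D^H y1
xadj1_loc = A.conj().T @ (Dop_glob.H @ y1_loc.ravel()).reshape(N, M)

print(y1_loc.shape, xadj1_loc.shape)  # (6, 4) (5, 4)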
