Skip to content

Commit 5656aba

Browse files
committed
Check for smaller values
1 parent 206d7be commit 5656aba

File tree

2 files changed

+51
-51
lines changed

2 files changed

+51
-51
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,4 @@ jobs:
4343
- name: Install pylops-mpi
4444
run: pip install .
4545
- name: Testing using pytest-mpi
46-
run: mpiexec -n ${{ matrix.rank }} pytest tests/ --with-mpi -v -x
46+
run: mpiexec -n ${{ matrix.rank }} pytest tests/test_blockdiag.py --with-mpi -v -x

tests/test_blockdiag.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
import pylops_mpi
1212
from pylops_mpi.utils.dottest import dottest
1313

14-
par1 = {'ny': 101, 'nx': 101, 'dtype': np.float64}
15-
par1j = {'ny': 101, 'nx': 101, 'dtype': np.complex128}
16-
par2 = {'ny': 301, 'nx': 101, 'dtype': np.float64}
17-
par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128}
14+
par1 = {'ny': 31, 'nx': 31, 'dtype': np.float64}
15+
par1j = {'ny': 41, 'nx': 41, 'dtype': np.complex128}
16+
par2 = {'ny': 31, 'nx': 41, 'dtype': np.float64}
17+
par2j = {'ny': 31, 'nx': 41, 'dtype': np.complex128}
1818

1919
np.random.seed(42)
2020

@@ -57,49 +57,49 @@ def test_blockdiag(par):
5757
assert_allclose(y_rmat_mpi, y_rmat_np, rtol=1e-14)
5858

5959

60-
# @pytest.mark.mpi(min_size=2)
61-
# @pytest.mark.parametrize("par", [(par1), (par1j), (par2), (par2j)])
62-
# def test_stacked_blockdiag(par):
63-
# """Tests for MPIStackedBlogDiag"""
64-
# size = MPI.COMM_WORLD.Get_size()
65-
# rank = MPI.COMM_WORLD.Get_rank()
66-
# Op = pylops.MatrixMult(A=((rank + 1) * np.ones(shape=(par['ny'], par['nx']))).astype(par['dtype']))
67-
# BDiag_MPI = pylops_mpi.MPIBlockDiag(ops=[Op, ])
68-
# FirstDeriv_MPI = pylops_mpi.MPIFirstDerivative(dims=(par['ny'], par['nx']), dtype=par['dtype'])
69-
# StackedBDiag_MPI = pylops_mpi.MPIStackedBlockDiag(ops=[BDiag_MPI, FirstDeriv_MPI])
70-
#
71-
# dist1 = pylops_mpi.DistributedArray(global_shape=size * par['nx'], dtype=par['dtype'])
72-
# dist1[:] = np.ones(dist1.local_shape, dtype=par['dtype'])
73-
# dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], dtype=par['dtype'])
74-
# dist2[:] = np.ones(dist2.local_shape, dtype=par['dtype'])
75-
# x = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2])
76-
# x_global = x.asarray()
77-
#
78-
# dist1 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], dtype=par['dtype'])
79-
# dist1[:] = np.ones(dist1.local_shape, dtype=par['dtype'])
80-
# dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], dtype=par['dtype'])
81-
# dist2[:] = np.ones(dist2.local_shape, dtype=par['dtype'])
82-
# y = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2])
83-
# y_global = y.asarray()
84-
#
85-
# # Forward
86-
# x_mat = StackedBDiag_MPI @ x
87-
# # Adjoint
88-
# y_rmat = StackedBDiag_MPI.H @ y
89-
# assert isinstance(x_mat, pylops_mpi.StackedDistributedArray)
90-
# assert isinstance(y_rmat, pylops_mpi.StackedDistributedArray)
91-
# # Dot test
92-
# dottest(StackedBDiag_MPI, x, y, size * par['ny'] + par['nx'] * par['ny'], size * par['nx'] + par['nx'] * par['ny'])
93-
#
94-
# x_mat_mpi = x_mat.asarray()
95-
# y_rmat_mpi = y_rmat.asarray()
96-
#
97-
# if rank == 0:
98-
# ops = [pylops.MatrixMult((i + 1) * np.ones(shape=(par['ny'], par['nx'])).astype(par['dtype'])) for i in range(size)]
99-
# BDiag = pylops.BlockDiag(ops=ops)
100-
# FirstDeriv = pylops.FirstDerivative(dims=(par['ny'], par['nx']), axis=0, dtype=par['dtype'])
101-
# BDiag_final = pylops.BlockDiag([BDiag, FirstDeriv])
102-
# x_mat_np = BDiag_final @ x_global
103-
# y_rmat_np = BDiag_final.H @ y_global
104-
# assert_allclose(x_mat_mpi, x_mat_np, rtol=1e-14)
105-
# assert_allclose(y_rmat_mpi, y_rmat_np, rtol=1e-14)
60+
@pytest.mark.mpi(min_size=2)
61+
@pytest.mark.parametrize("par", [(par1), (par1j), (par2), (par2j)])
62+
def test_stacked_blockdiag(par):
63+
"""Tests for MPIStackedBlogDiag"""
64+
size = MPI.COMM_WORLD.Get_size()
65+
rank = MPI.COMM_WORLD.Get_rank()
66+
Op = pylops.MatrixMult(A=((rank + 1) * np.ones(shape=(par['ny'], par['nx']))).astype(par['dtype']))
67+
BDiag_MPI = pylops_mpi.MPIBlockDiag(ops=[Op, ])
68+
FirstDeriv_MPI = pylops_mpi.MPIFirstDerivative(dims=(par['ny'], par['nx']), dtype=par['dtype'])
69+
StackedBDiag_MPI = pylops_mpi.MPIStackedBlockDiag(ops=[BDiag_MPI, FirstDeriv_MPI])
70+
71+
dist1 = pylops_mpi.DistributedArray(global_shape=size * par['nx'], dtype=par['dtype'])
72+
dist1[:] = np.ones(dist1.local_shape, dtype=par['dtype'])
73+
dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], dtype=par['dtype'])
74+
dist2[:] = np.ones(dist2.local_shape, dtype=par['dtype'])
75+
x = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2])
76+
x_global = x.asarray()
77+
78+
dist1 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], dtype=par['dtype'])
79+
dist1[:] = np.ones(dist1.local_shape, dtype=par['dtype'])
80+
dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], dtype=par['dtype'])
81+
dist2[:] = np.ones(dist2.local_shape, dtype=par['dtype'])
82+
y = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2])
83+
y_global = y.asarray()
84+
85+
# Forward
86+
x_mat = StackedBDiag_MPI @ x
87+
# Adjoint
88+
y_rmat = StackedBDiag_MPI.H @ y
89+
assert isinstance(x_mat, pylops_mpi.StackedDistributedArray)
90+
assert isinstance(y_rmat, pylops_mpi.StackedDistributedArray)
91+
# Dot test
92+
dottest(StackedBDiag_MPI, x, y, size * par['ny'] + par['nx'] * par['ny'], size * par['nx'] + par['nx'] * par['ny'])
93+
94+
x_mat_mpi = x_mat.asarray()
95+
y_rmat_mpi = y_rmat.asarray()
96+
97+
if rank == 0:
98+
ops = [pylops.MatrixMult((i + 1) * np.ones(shape=(par['ny'], par['nx'])).astype(par['dtype'])) for i in range(size)]
99+
BDiag = pylops.BlockDiag(ops=ops)
100+
FirstDeriv = pylops.FirstDerivative(dims=(par['ny'], par['nx']), axis=0, dtype=par['dtype'])
101+
BDiag_final = pylops.BlockDiag([BDiag, FirstDeriv])
102+
x_mat_np = BDiag_final @ x_global
103+
y_rmat_np = BDiag_final.H @ y_global
104+
assert_allclose(x_mat_mpi, x_mat_np, rtol=1e-14)
105+
assert_allclose(y_rmat_mpi, y_rmat_np, rtol=1e-14)

0 commit comments

Comments
 (0)