Skip to content

Commit 8d458b0

Browse files
committed
test: temporary fix to test_power
1 parent 51ab3e7 commit 8d458b0

File tree

1 file changed

+34
-33
lines changed

1 file changed

+34
-33
lines changed

tests/test_linearop.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -135,39 +135,40 @@ def test_scaled(par):
135135
assert_allclose(Sop_y_np, Sop.H @ y_global, rtol=1e-9)
136136

137137

138-
# @pytest.mark.mpi(min_size=2)
139-
# @pytest.mark.parametrize("par", [(par1), (par1j)])
140-
# def test_power(par):
141-
# """Test the PowerLinearOperator"""
142-
# Op = pylops.MatrixMult(A=((rank + 1) * np.ones(shape=(par['ny'], par['nx']))).astype(par['dtype']))
143-
# BDiag_MPI = MPIBlockDiag(ops=[Op, ])
144-
# # Power Operator
145-
# Pop_MPI = BDiag_MPI ** 3
146-
147-
# # Forward Mode
148-
# x = DistributedArray(global_shape=size * par['nx'], dtype=par['dtype'], engine=backend)
149-
# x[:] = np.ones(par['nx'])
150-
# x_global = x.asarray()
151-
# Pop_x = Pop_MPI @ x
152-
# assert isinstance(Pop_x, DistributedArray)
153-
# Pop_x_np = Pop_x.asarray()
154-
155-
# # Adjoint Mode
156-
# y = DistributedArray(global_shape=size * par['ny'], dtype=par['dtype'], engine=backend)
157-
# y[:] = np.ones(par['ny'])
158-
# y_global = y.asarray()
159-
# Pop_y = Pop_MPI.H @ y
160-
# assert isinstance(Pop_y, DistributedArray)
161-
# Pop_y_np = Pop_y.asarray()
162-
163-
# if rank == 0:
164-
# ops = [pylops.MatrixMult((i + 1) * np.ones(shape=(par['ny'], par['nx'])).astype(par['dtype'])) for i in
165-
# range(size)]
166-
# BDiag = pylops.BlockDiag(ops=ops)
167-
# # TODO (tharitt): Fail PyLops Op ** 3 does not preserve CuPy (it turns to NumPy)
168-
# Pop = BDiag ** 3
169-
# assert_allclose(Pop_x_np, Pop @ x_global, rtol=1e-9)
170-
# assert_allclose(Pop_y_np, Pop.H @ y_global, rtol=1e-9)
138+
@pytest.mark.mpi(min_size=2)
139+
@pytest.mark.parametrize("par", [(par1), (par1j)])
140+
def test_power(par):
141+
"""Test the PowerLinearOperator"""
142+
Op = pylops.MatrixMult(A=((rank + 1) * np.ones(shape=(par['ny'], par['nx']))).astype(par['dtype']),
143+
dtype=par['dtype'])
144+
BDiag_MPI = MPIBlockDiag(ops=[Op, ])
145+
146+
# Power Operator
147+
Pop_MPI = BDiag_MPI ** 3
148+
149+
# Forward Mode
150+
x = DistributedArray(global_shape=size * par['nx'], dtype=par['dtype'], engine=backend)
151+
x[:] = np.ones(par['nx'])
152+
x_global = x.asarray()
153+
Pop_x = Pop_MPI @ x
154+
assert isinstance(Pop_x, DistributedArray)
155+
Pop_x_np = Pop_x.asarray()
156+
157+
# Adjoint Mode
158+
y = DistributedArray(global_shape=size * par['ny'], dtype=par['dtype'], engine=backend)
159+
y[:] = np.ones(par['ny'])
160+
y_global = y.asarray()
161+
Pop_y = Pop_MPI.H @ y
162+
assert isinstance(Pop_y, DistributedArray)
163+
Pop_y_np = Pop_y.asarray()
164+
165+
if rank == 0:
166+
ops = [pylops.MatrixMult((i + 1) * np.ones(shape=(par['ny'], par['nx'])).astype(par['dtype'])) for i in
167+
range(size)]
168+
BDiag = pylops.BlockDiag(ops=ops)
169+
Pop = BDiag * BDiag * BDiag ## temporarely replaced BDiag ** 3 until bug in PyLops is fixed
170+
assert_allclose(Pop_x_np, Pop @ x_global, rtol=1e-9)
171+
assert_allclose(Pop_y_np, Pop.H @ y_global, rtol=1e-9)
171172

172173

173174
@pytest.mark.mpi(min_size=2)

0 commit comments

Comments
 (0)