Skip to content

Commit 5a457c4

Browse files
author
Olivier Leblanc
committed
add tests
1 parent 7810d5b commit 5a457c4

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

examples/plot_norms.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,22 @@
123123
plt.xlabel('x')
124124
plt.title(r'$||x||_1$')
125125
plt.legend()
126+
plt.tight_layout()
127+
128+
###############################################################################
129+
# We consider now the TV norm.
130+
TV = pyproximal.TV(dim=1, sigma=1.)
131+
132+
x = np.arange(-1, 1, 0.1)
133+
print('||x||_{TV}: ', l1(x))
134+
135+
tau = 0.5
136+
xp = TV.prox(x, tau)
137+
138+
plt.figure(figsize=(7, 2))
139+
plt.plot(x, x, 'k', lw=2, label='x')
140+
plt.plot(x, xp, 'r', lw=2, label='prox(x)')
141+
plt.xlabel('x')
142+
plt.title(r'$||x||_{TV}$')
143+
plt.legend()
126144
plt.tight_layout()

pytests/test_norms.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import numpy as np
44
from numpy.testing import assert_array_almost_equal
55

6-
from pylops.basicoperators import Identity, Diagonal, MatrixMult
6+
from pylops.basicoperators import Identity, Diagonal, MatrixMult, FirstDerivative
77
from pyproximal.utils import moreau
8-
from pyproximal.proximal import Box, Euclidean, L2, L1, L21, L21_plus_L1, Huber, Nuclear
8+
from pyproximal.proximal import Box, Euclidean, L2, L1, L21, L21_plus_L1, Huber, Nuclear, TV
99

1010
par1 = {'nx': 10, 'sigma': 1., 'dtype': 'float32'} # even float32
1111
par2 = {'nx': 11, 'sigma': 2., 'dtype': 'float64'} # odd float64
@@ -189,6 +189,18 @@ def test_Huber(par):
189189
assert moreau(hub, x, tau)
190190

191191

192+
@pytest.mark.parametrize("par", [(par1), (par2)])
193+
def test_TV(par):
194+
"""TV norm of x and proximal
195+
"""
196+
tv = TV(dim=1, sigma=par['sigma'])
197+
# norm
198+
x = np.random.normal(0., 1., par['nx']).astype(par['dtype'])
199+
derivOp = FirstDerivative(par['nx'], dtype=par['dtype'], kind='forward')
200+
dx = derivOp @ x
201+
assert_array_almost_equal(tv(x), par['sigma'] * np.sum(np.abs(dx), axis=0))
202+
203+
192204
def test_Nuclear_FOM():
193205
"""Nuclear norm benchmark with FOM solver
194206
"""
@@ -228,7 +240,7 @@ def test_Weighted_Nuclear(par):
228240
# the exact same singular values)
229241
X = np.random.uniform(0., 0.1, (par['nx'], 2 * par['nx'])).astype(par['dtype'])
230242
S = np.linalg.svd(X, compute_uv=False)
231-
assert (nucl(X.ravel()) - np.sum(weights[:S.size] * S)) < 1e-3
243+
assert (nucl(X.ravel()) - np.sum(weights[:S.size] * S)) < 1e-2
232244

233245
# prox / dualprox
234246
tau = 2.

0 commit comments

Comments
 (0)