Skip to content

Commit 38df3f7

Browse files
committed
test: added test to compare PG and GPG
1 parent eb10bd8 commit 38df3f7

File tree

3 files changed

+50
-347
lines changed

3 files changed

+50
-347
lines changed

examples/pyprox_deconv_pgm_compare.ipynb

Lines changed: 0 additions & 346 deletions
This file was deleted.

pyproximal/optimization/primal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
264264
if show:
265265
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
266266
pf, pg = proxf(x), proxg(x)
267-
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
267+
msg = '%6g %12.5e %10.3e %10.3e %10.3e %10.3e' % \
268268
(iiter + 1, np.real(to_numpy(x[0])) if x.ndim == 1 else np.real(to_numpy(x[0, 0])),
269269
pf, pg[0] if epsg_print == 'Multi' else pg,
270270
pf + np.sum(epsg * pg),

pytests/test_solver.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pytest
2+
3+
import numpy as np
4+
from numpy.testing import assert_array_almost_equal
5+
from pylops.basicoperators import MatrixMult
6+
from pyproximal.proximal import L1, L2
7+
from pyproximal.optimization.primal import ProximalGradient, GeneralizedProximalGradient
8+
9+
par1 = {'n': 8, 'm': 10, 'dtype': 'float32'} # float64
10+
par2 = {'n': 8, 'm': 10, 'dtype': 'float64'} # float32
11+
12+
13+
@pytest.mark.parametrize("par", [(par1), (par2)])
14+
def test_PG_GPG(par):
15+
"""Check equivalency of ProximalGradient and GeneralizedProximalGradient when using
16+
a single regularization term
17+
"""
18+
np.random.seed(0)
19+
n, m = par['n'], par['m']
20+
21+
# Define sparse model
22+
x = np.zeros(m)
23+
x[2], x[4] = 1, 0.5
24+
25+
# Random mixing matrix
26+
R = np.random.normal(0., 1., (n, m))
27+
Rop = MatrixMult(R)
28+
29+
y = Rop @ x
30+
31+
# Step size
32+
L = (Rop.H * Rop).eigs(1).real
33+
tau = 0.99 / L
34+
35+
# PG
36+
l2 = L2(Op=Rop, b=y, niter=10, warm=True)
37+
l1 = L1(sigma=5e-1)
38+
xpg = ProximalGradient(l2, l1, x0=np.zeros(m),
39+
tau=tau, niter=100,
40+
acceleration='fista')
41+
42+
# GPG
43+
l2 = L2(Op=Rop, b=y, niter=10, warm=True)
44+
l1 = L1(sigma=5e-1)
45+
xgpg = GeneralizedProximalGradient([l2, ], [l1, ], x0=np.zeros(m),
46+
tau=tau, niter=100,
47+
acceleration='fista')
48+
49+
assert_array_almost_equal(xpg, xgpg, decimal=2)

0 commit comments

Comments
 (0)