Skip to content

Commit c894e85

Browse files
authored
Merge pull request #137 from mrava87/doc-genproxgrad
doc: improved consistency in code and doc in GeneralizedProximalGradient
2 parents 488dcbd + 8e3bcbb commit c894e85

File tree

3 files changed

+67
-15
lines changed

3 files changed

+67
-15
lines changed

pyproximal/optimization/primal.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def GeneralizedProximalGradient(proxfs, proxgs, x0, tau=None,
308308
.. math::
309309
310310
\mathbf{x} = \argmin_\mathbf{x} \sum_{i=1}^n f_i(\mathbf{x})
311-
+ \sum_{j=1}^m \tau_j g_j(\mathbf{x}),~~n,m \in \mathbb{N}^+
311+
+ \sum_{j=1}^m \epsilon_j g_j(\mathbf{x}),~~n,m \in \mathbb{N}^+
312312
313313
where the :math:`f_i(\mathbf{x})` are smooth convex functions with a uniquely
314314
defined gradient and the :math:`g_j(\mathbf{x})` are any convex function that
@@ -329,7 +329,7 @@ def GeneralizedProximalGradient(proxfs, proxgs, x0, tau=None,
329329
backtracking is used to adaptively estimate the best tau at each
330330
iteration.
331331
epsg : :obj:`float` or :obj:`np.ndarray`, optional
332-
Scaling factor of g function
332+
Scaling factor(s) of ``g`` function(s)
333333
niter : :obj:`int`, optional
334334
Number of iterations of iterative scheme
335335
acceleration: :obj:`str`, optional
@@ -352,11 +352,12 @@ def GeneralizedProximalGradient(proxfs, proxgs, x0, tau=None,
352352
353353
.. math::
354354
\text{for } j=1,\cdots,n, \\
355-
~~~~\mathbf z_j^{k+1} = \mathbf z_j^{k} + \eta_k (prox_{\frac{\tau^k}{\omega_j} g_j}(2 \mathbf{x}^{k} - z_j^{k})
356-
- \tau^k \sum_{i=1}^n \nabla f_i(\mathbf{x}^{k})) - \mathbf{x}^{k} \\
355+
~~~~\mathbf z_j^{k+1} = \mathbf z_j^{k} + \epsilon_j
356+
\left[prox_{\frac{\tau^k}{\omega_j} g_j}\left(2 \mathbf{x}^{k} - \mathbf{z}_j^{k}
357+
- \tau^k \sum_{i=1}^n \nabla f_i(\mathbf{x}^{k})\right) - \mathbf{x}^{k} \right] \\
357358
\mathbf{x}^{k+1} = \sum_{j=1}^n \omega_j f_j \\
358359
359-
where :math:`\sum_{j=1}^n \omega_j=1`.
360+
where :math:`\sum_{j=1}^n \omega_j=1`. In the current implementation :math:`\omega_j=1/n`.
360361
"""
361362
# check if epgs is a vector
362363
if np.asarray(epsg).size == 1.:
@@ -393,23 +394,23 @@ def GeneralizedProximalGradient(proxfs, proxgs, x0, tau=None,
393394
for iiter in range(niter):
394395
xold = x.copy()
395396

396-
# proximal step
397+
# gradient
397398
grad = np.zeros_like(x)
398399
for i, proxf in enumerate(proxfs):
399400
grad += proxf.grad(x)
400401

401-
sol = np.zeros_like(x)
402+
# proximal step
403+
x = np.zeros_like(x)
402404
for i, proxg in enumerate(proxgs):
403-
tmp = 2 * y - zs[i] - tau * grad
404-
tmp[:] = proxg.prox(tmp, tau *len(proxgs) )
405-
zs[i] += epsg * (tmp - y)
406-
sol += zs[i] / len(proxgs)
407-
x[:] = sol.copy()
405+
ztmp = 2 * y - zs[i] - tau * grad
406+
ztmp = proxg.prox(ztmp, tau * len(proxgs))
407+
zs[i] += epsg * (ztmp - y)
408+
x += zs[i] / len(proxgs)
408409

409410
# update y
410411
if acceleration == 'vandenberghe':
411412
omega = iiter / (iiter + 3)
412-
elif acceleration== 'fista':
413+
elif acceleration == 'fista':
413414
told = t
414415
t = (1. + np.sqrt(1. + 4. * t ** 2)) / 2.
415416
omega = ((told - 1.) / t)
@@ -781,7 +782,7 @@ def ADMML2(proxg, Op, b, A, x0, tau, niter=10, callback=None, show=False, **kwar
781782

782783
if show:
783784
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
784-
pf, pg = np.linalg.norm(Op @ x - b), proxg(Ax)
785+
pf, pg = 0.5 * np.linalg.norm(Op @ x - b) ** 2, proxg(Ax)
785786
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
786787
(iiter + 1, x[0], pf, pg, pf + pg)
787788
print(msg)

pyproximal/proximal/VStack.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,16 @@ def __init__(self, ops, nn=None, restr=None):
5151
self.xin = cum_nn[:-1]
5252
self.xin = np.insert(self.xin, 0, 0)
5353
self.xend = cum_nn
54+
# store required size of input
55+
self.nx = cum_nn[-1]
5456
else:
5557
self.restr = restr
58+
# store required size of input
59+
self.nx = np.sum([restr.iava.size for restr in self.restr])
5660

5761
def __call__(self, x):
62+
if x.size != self.nx:
63+
raise ValueError(f'x must have size {self.nx}, instead the provided x has size {x.size}')
5864
f = 0.
5965
if hasattr(self, 'nn'):
6066
for iop, op in enumerate(self.ops):
@@ -66,6 +72,8 @@ def __call__(self, x):
6672

6773
@_check_tau
6874
def prox(self, x, tau):
75+
if x.size != self.nx:
76+
raise ValueError(f'x must have size {self.nx}, instead the provided x has size {x.size}')
6977
if hasattr(self, 'nn'):
7078
f = np.hstack([op.prox(x[self.xin[iop]:self.xend[iop]], tau)
7179
for iop, op in enumerate(self.ops)])
@@ -76,6 +84,8 @@ def prox(self, x, tau):
7684
return f
7785

7886
def grad(self, x):
87+
if x.size != self.nx:
88+
raise ValueError(f'x must have size {self.nx}, instead the provided x has size {x.size}')
7989
if hasattr(self, 'nn'):
8090
f = np.hstack([op.grad(x[self.xin[iop]:self.xend[iop]])
8191
for iop, op in enumerate(self.ops)])

pytests/test_proximal.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
from numpy.testing import assert_array_equal, assert_array_almost_equal
5-
from pylops import MatrixMult, Identity
5+
from pylops import Identity, MatrixMult, Restriction
66

77
import pyproximal
88
from pyproximal.utils import moreau
@@ -84,6 +84,20 @@ def test_Orthogonal(par):
8484
assert moreau(orth, x, tau)
8585

8686

87+
@pytest.mark.parametrize("par", [(par1), (par2)])
88+
def test_VStack_error(par):
89+
"""VStack operator error when input has wrong dimensions
90+
"""
91+
np.random.seed(10)
92+
nxs = [par['nx'] // 4] * 4
93+
nxs[-1] = par['nx'] - np.sum(nxs[:-1])
94+
l2 = L2()
95+
vstack = VStack([l2] * 4, nxs)
96+
97+
with pytest.raises(ValueError):
98+
vstack.prox(np.ones(nxs[0]), 2)
99+
100+
87101
@pytest.mark.parametrize("par", [(par1), (par2)])
88102
def test_VStack(par):
89103
"""L2 functional with VStack operator of multiple L1s
@@ -109,6 +123,33 @@ def test_VStack(par):
109123
assert moreau(vstack, x, tau)
110124

111125

126+
@pytest.mark.parametrize("par", [(par1), ])
127+
def test_VStack_restriction(par):
128+
"""L2 functional with VStack operator of multiple L1s using restriction
129+
"""
130+
np.random.seed(10)
131+
nxs = [par['nx'] // 2] * 2
132+
nxs[-1] = par['nx'] - np.sum(nxs[:-1])
133+
l2 = L2()
134+
vstack = VStack([l2] * 2,
135+
restr=[Restriction(par['nx'], np.arange(par['nx'] // 2)),
136+
Restriction(par['nx'], par['nx'] // 2 + np.arange(par['nx'] // 2))])
137+
138+
# functional
139+
x = np.random.normal(0., 1., par['nx']).astype(par['dtype'])
140+
assert_array_almost_equal(l2(x), vstack(x), decimal=4)
141+
142+
# gradient
143+
assert_array_almost_equal(l2.grad(x), vstack.grad(x), decimal=4)
144+
145+
# prox / dualprox
146+
tau = 2.
147+
assert_array_equal(l2.prox(x, tau), vstack.prox(x, tau))
148+
149+
# moreau
150+
assert moreau(vstack, x, tau)
151+
152+
112153
def test_Nonlinear():
113154
"""Nonlinear proximal operator. Since this is a template class simply check
114155
that errors are raised when not used properly

0 commit comments

Comments
 (0)