Skip to content

Commit 7810d5b

Browse files
author
Olivier Leblanc
committed
merge ProxGrad and its accel, add Generalization
1 parent 7e26fe4 commit 7810d5b

File tree

2 files changed

+103
-67
lines changed

2 files changed

+103
-67
lines changed

pyproximal/optimization/primal.py

Lines changed: 100 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,13 @@ def ProximalPoint(prox, x0, tau, niter=10, callback=None, show=False):
100100

101101

102102
def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
103-
epsg=1., niter=10, niterback=100,
104-
callback=None, show=False):
105-
r"""Proximal gradient
103+
epsg=1., niter=10, niterback=100,
104+
acceleration='vandenberghe',
105+
callback=None, show=False):
106+
r"""Proximal gradient (optionally accelerated)
106107
107-
Solves the following minimization problem using Proximal gradient
108-
algorithm:
108+
Solves the following minimization problem using (Accelerated) Proximal
109+
gradient algorithm:
109110
110111
.. math::
111112
@@ -138,6 +139,8 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
138139
Number of iterations of iterative scheme
139140
niterback : :obj:`int`, optional
140141
Max number of iterations of backtracking
142+
acceleration : :obj:`str`, optional
143+
Acceleration (``None``, ``vandenberghe`` or ``fista``)
141144
callback : :obj:`callable`, optional
142145
Function with signature (``callback(x)``) to call after each iteration
143146
where ``x`` is the current model vector
@@ -151,12 +154,15 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
151154
152155
Notes
153156
-----
154-
The Proximal point algorithm can be expressed by the following recursion:
157+
The (Accelerated) Proximal gradient algorithm can be expressed by the
158+
following recursion:
155159
156160
.. math::
157161
158-
\mathbf{x}^{k+1} = \prox_{\tau^k \epsilon g}(\mathbf{x}^k -
159-
\tau^k \nabla f(\mathbf{x}^k))
162+
\mathbf{y}^{k+1} = \mathbf{x}^k + \omega^k
163+
(\mathbf{x}^k - \mathbf{x}^{k-1}) \\
164+
\mathbf{x}^{k+1} = \prox_{\tau^k \epsilon g}(\mathbf{y}^{k+1} -
165+
\tau^k \nabla f(\mathbf{y}^{k+1}))
160166
161167
where at each iteration :math:`\tau^k` can be estimated by back-tracking
162168
as follows:
@@ -173,7 +179,17 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
173179
174180
where :math:`\tilde{f}_\tau(\mathbf{x}, \mathbf{y}) = f(\mathbf{y}) +
175181
\nabla f(\mathbf{y})^T (\mathbf{x} - \mathbf{y}) +
176-
1/(2\tau)||\mathbf{x} - \mathbf{y}||_2^2`.
182+
1/(2\tau)||\mathbf{x} - \mathbf{y}||_2^2`,
183+
and
184+
:math:`\omega^k = 0` for ``acceleration=None``,
185+
:math:`\omega^k = k / (k + 3)` for ``acceleration=vandenberghe`` [1]_
186+
or :math:`\omega^k = (t_{k-1}-1)/t_k` for ``acceleration=fista`` where
187+
:math:`t_k = (1 + \sqrt{1+4t_{k-1}^{2}}) / 2` [2]_
188+
189+
.. [1] Vandenberghe, L., "Fast proximal gradient methods", 2010.
190+
.. [2] Beck, A., and Teboulle, M. "A Fast Iterative Shrinkage-Thresholding
191+
Algorithm for Linear Inverse Problems", SIAM Journal on
192+
Imaging Sciences, vol. 2, pp. 183-202. 2009.
177193
178194
"""
179195
# check if epsg is a vector
@@ -182,9 +198,12 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
182198
else:
183199
epsg_print = 'Multi'
184200

201+
if acceleration not in [None, 'None', 'vandenberghe', 'fista']:
202+
raise NotImplementedError('Acceleration should be None, vandenberghe '
203+
'or fista')
185204
if show:
186205
tstart = time.time()
187-
print('Proximal Gradient\n'
206+
print('Accelerated Proximal Gradient\n'
188207
'---------------------------------------------------------\n'
189208
'Proximal operator (f): %s\n'
190209
'Proximal operator (g): %s\n'
@@ -201,13 +220,32 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
201220
backtracking = True
202221
tau = 1.
203222

223+
# initialize model
224+
t = 1.
204225
x = x0.copy()
226+
y = x.copy()
227+
228+
# iterate
205229
for iiter in range(niter):
230+
xold = x.copy()
231+
232+
# proximal step
206233
if not backtracking:
207-
x = proxg.prox(x - tau * proxf.grad(x), epsg * tau)
234+
x = proxg.prox(y - tau * proxf.grad(y), epsg * tau)
208235
else:
209-
x, tau = _backtracking(x, tau, proxf, proxg, epsg,
236+
x, tau = _backtracking(y, tau, proxf, proxg, epsg,
210237
beta=beta, niterback=niterback)
238+
239+
# update y
240+
if acceleration == 'vandenberghe':
241+
omega = iiter / (iiter + 3)
242+
elif acceleration== 'fista':
243+
told = t
244+
t = (1. + np.sqrt(1. + 4. * t ** 2)) / 2.
245+
omega = ((told - 1.) / t)
246+
else:
247+
omega = 0
248+
y = x + omega * (x - xold)
211249

212250
# run callback
213251
if callback is not None:
@@ -226,47 +264,44 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
226264
print('---------------------------------------------------------\n')
227265
return x
228266

229-
230-
def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
231-
epsg=1., niter=10, niterback=100,
232-
acceleration='vandenberghe',
267+
def GeneralizedProximalGradient(proxfs, proxgs, x0, tau=None, beta=0.5,
268+
epsg=1., niter=10,
269+
acceleration='None',
233270
callback=None, show=False):
234-
r"""Accelerated Proximal gradient
271+
r"""Generalized Proximal gradient
235272
236-
Solves the following minimization problem using Accelerated Proximal
273+
Solves the following minimization problem using Generalized Proximal
237274
gradient algorithm:
238275
239276
.. math::
240277
241-
\mathbf{x} = \argmin_\mathbf{x} f(\mathbf{x}) + \epsilon g(\mathbf{x})
278+
\mathbf{x} = \argmin_\mathbf{x} \sum_{i=1}^n f_i(\mathbf{x})
279+
+ \sum_{j=1}^m \tau_j g_j(\mathbf{x}),~~n,m \in \mathbb{N}^+
242280
243-
where :math:`f(\mathbf{x})` is a smooth convex function with a uniquely
244-
defined gradient and :math:`g(\mathbf{x})` is any convex function that
245-
has a known proximal operator.
281+
where the :math:`f_i(\mathbf{x})` are smooth convex functions with a uniquely
282+
defined gradient and the :math:`g_j(\mathbf{x})` are any convex functions that
283+
have a known proximal operator.
246284
247285
Parameters
248286
----------
249-
proxf : :obj:`pyproximal.ProxOperator`
250-
Proximal operator of f function (must have ``grad`` implemented)
251-
proxg : :obj:`pyproximal.ProxOperator`
252-
Proximal operator of g function
287+
proxfs : :obj:`list` of :obj:`pyproximal.ProxOperator`
288+
Proximal operators of the f_i functions (must have ``grad`` implemented)
289+
proxgs : :obj:`list` of :obj:`pyproximal.ProxOperator`
290+
Proximal operators of the g_j functions
253291
x0 : :obj:`numpy.ndarray`
254292
Initial vector
255293
tau : :obj:`float` or :obj:`numpy.ndarray`, optional
256294
Positive scalar weight, which should satisfy the following condition
257295
to guarantee convergence: :math:`\tau \in (0, 1/L]` where ``L`` is
258-
the Lipschitz constant of :math:`\nabla f`. When ``tau=None``,
296+
the Lipschitz constant of :math:`\sum_{i=1}^n \nabla f_i`. When ``tau=None``,
259297
backtracking is used to adaptively estimate the best tau at each
260-
iteration. Finally note that :math:`\tau` can be chosen to be a vector
261-
when dealing with problems with multiple right-hand-sides
298+
iteration.
262299
beta : :obj:`float`, optional
263300
Backtracking parameter (must be between 0 and 1)
264301
epsg : :obj:`float` or :obj:`np.ndarray`, optional
265302
Scaling factor of g function
266303
niter : :obj:`int`, optional
267304
Number of iterations of iterative scheme
268-
niterback : :obj:`int`, optional
269-
Max number of iterations of backtracking
270305
acceleration : :obj:`str`, optional
271306
Acceleration (``None``, ``vandenberghe`` or ``fista``)
272307
callback : :obj:`callable`, optional
@@ -282,76 +317,75 @@ def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
282317
283318
Notes
284319
-----
285-
The Accelerated Proximal point algorithm can be expressed by the
320+
The Generalized Proximal gradient algorithm can be expressed by the
286321
following recursion:
287322
288323
.. math::
289-
290-
\mathbf{x}^{k+1} = \prox_{\tau^k \epsilon g}(\mathbf{y}^{k+1} -
291-
\tau^k \nabla f(\mathbf{y}^{k+1})) \\
292-
\mathbf{y}^{k+1} = \mathbf{x}^k + \omega^k
293-
(\mathbf{x}^k - \mathbf{x}^{k-1})
294-
295-
where :math:`\omega^k = k / (k + 3)` for ``acceleration=vandenberghe`` [1]_
296-
or :math:`\omega^k = (t_{k-1}-1)/t_k` for ``acceleration=fista`` where
297-
:math:`t_k = (1 + \sqrt{1+4t_{k-1}^{2}}) / 2` [2]_
298-
299-
.. [1] Vandenberghe, L., "Fast proximal gradient methods", 2010.
300-
.. [2] Beck, A., and Teboulle, M. "A Fast Iterative Shrinkage-Thresholding
301-
Algorithm for Linear Inverse Problems", SIAM Journal on
302-
Imaging Sciences, vol. 2, pp. 183-202. 2009.
303-
324+
\text{for } j=1,\cdots,m, \\
325+
~~~~\mathbf{z}_j^{k+1} = \mathbf{z}_j^{k} + \eta_k (\prox_{\frac{\tau^k}{\omega_j} g_j}(2 \mathbf{x}^{k} - \mathbf{z}_j^{k}
326+
- \tau^k \sum_{i=1}^n \nabla f_i(\mathbf{x}^{k})) - \mathbf{x}^{k}) \\
327+
\mathbf{x}^{k+1} = \sum_{j=1}^m \omega_j \mathbf{z}_j^{k+1}
328+
329+
where :math:`\sum_{j=1}^m \omega_j=1`.
304330
"""
305-
# check if epgs is a ve
331+
# check if epsg is a vector
306332
if np.asarray(epsg).size == 1.:
307333
epsg_print = str(epsg)
308334
else:
309335
epsg_print = 'Multi'
310336

311-
if acceleration not in ['vandenberghe', 'fista']:
312-
raise NotImplementedError('Acceleration should be vandenberghe '
337+
if acceleration not in [None, 'None', 'vandenberghe', 'fista']:
338+
raise NotImplementedError('Acceleration should be None, vandenberghe '
313339
'or fista')
314340
if show:
315341
tstart = time.time()
316-
print('Accelerated Proximal Gradient\n'
342+
print('Generalized Proximal Gradient\n'
317343
'---------------------------------------------------------\n'
318-
'Proximal operator (f): %s\n'
319-
'Proximal operator (g): %s\n'
320-
'tau = %10e\tepsg = %s\tniter = %d\n' % (type(proxf),
321-
type(proxg),
322-
0 if tau is None else tau,
323-
epsg_print, niter))
344+
'Proximal operators (f): %s\n'
345+
'Proximal operators (g): %s\n'
346+
'tau = %10e\tbeta=%10e\nepsg = %s\tniter = %d\n' % ([type(proxf) for proxf in proxfs],
347+
[type(proxg) for proxg in proxgs],
348+
0 if tau is None else tau,
349+
beta, epsg_print, niter))
324350
head = ' Itn x[0] f g J=f+eps*g'
325351
print(head)
326352

327-
backtracking = False
328353
if tau is None:
329-
backtracking = True
330354
tau = 1.
331355

332356
# initialize model
333357
t = 1.
334358
x = x0.copy()
335359
y = x.copy()
360+
zs = [x.copy() for _ in range(len(proxgs))]
336361

337362
# iterate
338363
for iiter in range(niter):
339364
xold = x.copy()
340365

341366
# proximal step
342-
if not backtracking:
343-
x = proxg.prox(y - tau * proxf.grad(y), epsg * tau)
344-
else:
345-
x, tau = _backtracking(y, tau, proxf, proxg, epsg,
346-
beta=beta, niterback=niterback)
367+
grad = np.zeros_like(x)
368+
for i, proxf in enumerate(proxfs):
369+
grad += proxf.grad(x)
370+
371+
sol = np.zeros_like(x)
372+
for i, proxg in enumerate(proxgs):
373+
tmp = 2 * y - zs[i] - tau * grad
374+
tmp[:] = proxg.prox(tmp, tau *len(proxgs) )
375+
zs[i] += epsg * (tmp - y)
376+
sol += zs[i]/len(proxgs)
377+
x[:] = sol.copy()
347378

348379
# update y
349380
if acceleration == 'vandenberghe':
350381
omega = iiter / (iiter + 3)
351-
else:
382+
elif acceleration== 'fista':
352383
told = t
353384
t = (1. + np.sqrt(1. + 4. * t ** 2)) / 2.
354385
omega = ((told - 1.) / t)
386+
else:
387+
omega = 0
388+
355389
y = x + omega * (x - xold)
356390

357391
# run callback
@@ -360,7 +394,7 @@ def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
360394

361395
if show:
362396
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
363-
pf, pg = proxf(x), proxg(x)
397+
pf, pg = np.sum([proxf(x) for proxf in proxfs]), np.sum([proxg(x) for proxg in proxgs])
364398
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
365399
(iiter + 1, x[0] if x.ndim == 1 else x[0, 0],
366400
pf, pg[0] if epsg_print == 'Multi' else pg,

pyproximal/proximal/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
L21 L2,1 Norm
2121
L21_plus_L1 L2,1 + L1 mixed-norm
2222
Huber Huber Norm
23+
TV Total Variation Norm
2324
Nuclear Nuclear Norm
2425
NuclearBall Nuclear Ball
2526
Orthogonal Product between orthogonal operator and vector
@@ -46,6 +47,7 @@
4647
from .L21 import *
4748
from .L21_plus_L1 import *
4849
from .Huber import *
50+
from .TV import *
4951
from .Nuclear import *
5052
from .Orthogonal import *
5153
from .VStack import *
@@ -58,6 +60,6 @@
5860

5961
__all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic',
6062
'Euclidean', 'EuclideanBall', 'L0Ball', 'L1', 'L1Ball', 'L2',
61-
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'Nuclear',
63+
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'Nuclear',
6264
'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
6365
'Log', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty']

0 commit comments

Comments
 (0)