Skip to content

Commit 0090e34

Browse files
authored
Merge pull request #87 from mrava87/main
HQS solver and L0 norm
2 parents 4a13fff + 6941d00 commit 0090e34

File tree

8 files changed

+254
-25
lines changed

8 files changed

+254
-25
lines changed

docs/source/api/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ Convex
6464
EuclideanBall
6565
Huber
6666
Intersection
67+
L0
6768
L0Ball
6869
L1
6970
L1Ball
@@ -136,6 +137,7 @@ Primal
136137
AcceleratedProximalGradient
137138
ADMM
138139
ADMML2
140+
HQS
139141
LinearizedADMM
140142
ProximalGradient
141143
ProximalPoint

pyproximal/optimization/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
ProximalPoint Proximal point algorithm (or proximal min.)
1212
ProximalGradient Proximal gradient algorithm
1313
AcceleratedProximalGradient Accelerated Proximal gradient algorithm
14+
HQS Half Quadrating Splitting
1415
ADMM Alternating Direction Method of Multipliers
15-
ADMML2 ADMM with L2 misfit term
16+
ADMML2 ADMM with L2 misfit term
1617
LinearizedADMM Linearized ADMM
1718
TwIST Two-step Iterative Shrinkage/Threshold
1819
PlugAndPlay Plug-and-Play Prior with ADMM

pyproximal/optimization/primal.py

Lines changed: 125 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,115 @@ def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
372372
return x
373373

374374

375-
def ADMM(proxf, proxg, x0, tau, niter=10, callback=None, show=False):
375+
def HQS(proxf, proxg, x0, tau, niter=10, gfirst=True,
376+
callback=None, callbackz=False, show=False):
377+
r"""Half Quadratic splitting
378+
379+
Solves the following minimization problem using Half Quadratic splitting
380+
algorithm:
381+
382+
.. math::
383+
384+
\mathbf{x},\mathbf{z} = \argmin_{\mathbf{x},\mathbf{z}}
385+
f(\mathbf{x}) + g(\mathbf{z}) \\
386+
s.t. \; \mathbf{x}=\mathbf{z}
387+
388+
where :math:`f(\mathbf{x})` and :math:`g(\mathbf{z})` are any convex
389+
function that has a known proximal operator.
390+
391+
Parameters
392+
----------
393+
proxf : :obj:`pyproximal.ProxOperator`
394+
Proximal operator of f function
395+
proxg : :obj:`pyproximal.ProxOperator`
396+
Proximal operator of g function
397+
x0 : :obj:`numpy.ndarray`
398+
Initial vector
399+
tau : :obj:`float`, optional
400+
Positive scalar weight, which should satisfy the following condition
401+
to guarantees convergence: :math:`\tau \in (0, 1/L]` where ``L`` is
402+
the Lipschitz constant of :math:`\nabla f`.
403+
niter : :obj:`int`, optional
404+
Number of iterations of iterative scheme
405+
gfirst : :obj:`bool`, optional
406+
Apply Proximal of operator ``g`` first (``True``) or Proximal of
407+
operator ``f`` first (``False``)
408+
callback : :obj:`callable`, optional
409+
Function with signature (``callback(x)``) to call after each iteration
410+
where ``x`` is the current model vector
411+
callbackz : :obj:`bool`, optional
412+
Modify callback signature to (``callback(x, z)``) when ``callbackz=True``
413+
show : :obj:`bool`, optional
414+
Display iterations log
415+
416+
Returns
417+
-------
418+
x : :obj:`numpy.ndarray`
419+
Inverted model
420+
z : :obj:`numpy.ndarray`
421+
Inverted second model
422+
423+
Notes
424+
-----
425+
The HQS algorithm can be expressed by the following recursion [1]_:
426+
427+
.. math::
428+
429+
\mathbf{z}^{k+1} = \prox_{\tau g}(\mathbf{x}^{k})
430+
\mathbf{x}^{k+1} = \prox_{\tau f}(\mathbf{z}^{k+1})\\
431+
432+
Note that ``x`` and ``z`` converge to each other, however if iterations are
433+
stopped too early ``x`` is guaranteed to belong to the domain of ``f``
434+
while ``z`` is guaranteed to belong to the domain of ``g``. Depending on
435+
the problem either of the two may be the best solution.
436+
437+
.. [1] D., Geman, and C., Yang, "Nonlinear image recovery with halfquadratic
438+
regularization", IEEE Transactions on Image Processing,
439+
4, 7, pp. 932-946, 1995.
440+
441+
"""
442+
if show:
443+
tstart = time.time()
444+
print('HQS\n'
445+
'---------------------------------------------------------\n'
446+
'Proximal operator (f): %s\n'
447+
'Proximal operator (g): %s\n'
448+
'tau = %10e\tniter = %d\n' % (type(proxf), type(proxg),
449+
tau, niter))
450+
head = ' Itn x[0] f g J = f + g'
451+
print(head)
452+
453+
x = x0.copy()
454+
z = np.zeros_like(x)
455+
for iiter in range(niter):
456+
if gfirst:
457+
z = proxg.prox(x, tau)
458+
x = proxf.prox(z, tau)
459+
else:
460+
x = proxf.prox(z, tau)
461+
z = proxg.prox(x, tau)
462+
463+
# run callback
464+
if callback is not None:
465+
if callbackz:
466+
callback(x, z)
467+
else:
468+
callback(x)
469+
470+
if show:
471+
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
472+
pf, pg = proxf(x), proxg(x)
473+
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
474+
(iiter + 1, x[0], pf, pg, pf + pg)
475+
print(msg)
476+
if show:
477+
print('\nTotal time (s) = %.2f' % (time.time() - tstart))
478+
print('---------------------------------------------------------\n')
479+
return x, z
480+
481+
482+
def ADMM(proxf, proxg, x0, tau, niter=10, gfirst=False,
483+
callback=None, callbackz=False, show=False):
376484
r"""Alternating Direction Method of Multipliers
377485
378486
Solves the following minimization problem using Alternating Direction
@@ -417,9 +525,14 @@ def ADMM(proxf, proxg, x0, tau, niter=10, callback=None, show=False):
417525
the Lipschitz constant of :math:`\nabla f`.
418526
niter : :obj:`int`, optional
419527
Number of iterations of iterative scheme
528+
gfirst : :obj:`bool`, optional
529+
Apply Proximal of operator ``g`` first (``True``) or Proximal of
530+
operator ``f`` first (``False``)
420531
callback : :obj:`callable`, optional
421532
Function with signature (``callback(x)``) to call after each iteration
422533
where ``x`` is the current model vector
534+
callbackz : :obj:`bool`, optional
535+
Modify callback signature to (``callback(x, z)``) when ``callbackz=True``
423536
show : :obj:`bool`, optional
424537
Display iterations log
425538
@@ -445,7 +558,7 @@ def ADMM(proxf, proxg, x0, tau, niter=10, callback=None, show=False):
445558
\mathbf{z}^{k+1} = \prox_{\tau g}(\mathbf{x}^{k+1} + \mathbf{u}^{k})\\
446559
\mathbf{u}^{k+1} = \mathbf{u}^{k} + \mathbf{x}^{k+1} - \mathbf{z}^{k+1}
447560
448-
Note that ``x`` and ``z`` converge to each other, but if iterations are
561+
Note that ``x`` and ``z`` converge to each other, however if iterations are
449562
stopped too early ``x`` is guaranteed to belong to the domain of ``f``
450563
while ``z`` is guaranteed to belong to the domain of ``g``. Depending on
451564
the problem either of the two may be the best solution.
@@ -465,14 +578,20 @@ def ADMM(proxf, proxg, x0, tau, niter=10, callback=None, show=False):
465578
x = x0.copy()
466579
u = z = np.zeros_like(x)
467580
for iiter in range(niter):
468-
x = proxf.prox(z - u, tau)
469-
z = proxg.prox(x + u, tau)
581+
if gfirst:
582+
z = proxg.prox(x + u, tau)
583+
x = proxf.prox(z - u, tau)
584+
else:
585+
x = proxf.prox(z - u, tau)
586+
z = proxg.prox(x + u, tau)
470587
u = u + x - z
471588

472589
# run callback
473590
if callback is not None:
474-
callback(x)
475-
591+
if callbackz:
592+
callback(x, z)
593+
else:
594+
callback(x)
476595
if show:
477596
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
478597
pf, pg = proxf(x), proxg(x)

pyproximal/optimization/primaldual.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
6-
gfirst=True, callback=None, show=False):
6+
gfirst=True, callback=None, callbacky=False, show=False):
77
r"""Primal-dual algorithm
88
99
Solves the following (possibly) nonlinear minimization problem using
@@ -39,10 +39,12 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
3939
Linear operator of g
4040
x0 : :obj:`numpy.ndarray`
4141
Initial vector
42-
tau : :obj:`float`
43-
Stepsize of subgradient of :math:`f`
44-
mu : :obj:`float`
45-
Stepsize of subgradient of :math:`g^*`
42+
tau : :obj:`float` or :obj:`np.ndarray`
43+
Stepsize of subgradient of :math:`f`. This can be constant
44+
or function of iterations (in the latter cases provided as np.ndarray)
45+
mu : :obj:`float` or :obj:`np.ndarray`
46+
Stepsize of subgradient of :math:`g^*`. This can be constant
47+
or function of iterations (in the latter cases provided as np.ndarray)
4648
z : :obj:`numpy.ndarray`, optional
4749
Additional vector
4850
theta : :obj:`float`
@@ -58,6 +60,8 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
5860
callback : :obj:`callable`, optional
5961
Function with signature (``callback(x)``) to call after each iteration
6062
where ``x`` is the current model vector
63+
callbacky : :obj:`bool`, optional
64+
Modify callback signature to (``callback(x, y)``) when ``callbacky=True``
6165
show : :obj:`bool`, optional
6266
Display iterations log
6367
@@ -93,11 +97,20 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
9397
\mathbf{y}^{k+1} = \prox_{\mu g^*}(\mathbf{y}^{k} +
9498
\mu \mathbf{A}\bar{\mathbf{x}}^{k+1})
9599
96-
.. [1] A., Chambolle, and T., Pock, "A first-order primal-dual algorithm for
100+
.. [1] A., Chambolle, and T., Pock, "A first-order primal-dual algorithm for
97101
convex problems with applications to imaging", Journal of Mathematical
98-
Imaging and Vision, 40, 8pp. 120145. 2011.
102+
Imaging and Vision, 40, 8pp. 120-145. 2011.
99103
100104
"""
105+
# check if tau and mu are scalars or arrays
106+
fixedtau = fixedmu = False
107+
if isinstance(tau, (int, float)):
108+
tau = tau * np.ones(niter)
109+
fixedtau = True
110+
if isinstance(mu, (int, float)):
111+
mu = mu * np.ones(niter)
112+
fixedmu = True
113+
101114
if show:
102115
tstart = time.time()
103116
print('Primal-dual: min_x f(Ax) + x^T z + g(x)\n'
@@ -106,9 +119,10 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
106119
'Proximal operator (g): %s\n'
107120
'Linear operator (A): %s\n'
108121
'Additional vector (z): %s\n'
109-
'tau = %10e\tmu = %10e\ntheta = %.2f\t\tniter = %d\n' %
122+
'tau = %s\t\tmu = %s\ntheta = %.2f\t\tniter = %d\n' %
110123
(type(proxf), type(proxg), type(A),
111-
None if z is None else 'vector', tau, mu, theta, niter))
124+
None if z is None else 'vector', str(tau[0]) if fixedtau else 'Variable',
125+
str(mu[0]) if fixedmu else 'Variable', theta, niter))
112126
head = ' Itn x[0] f g z^x J = f + g + z^x'
113127
print(head)
114128

@@ -119,24 +133,26 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
119133
for iiter in range(niter):
120134
xold = x.copy()
121135
if gfirst:
122-
y = proxg.proxdual(y + mu * A.matvec(xhat), mu)
136+
y = proxg.proxdual(y + mu[iiter] * A.matvec(xhat), mu[iiter])
123137
ATy = A.rmatvec(y)
124138
if z is not None:
125139
ATy += z
126-
x = proxf.prox(x - tau * ATy, tau)
140+
x = proxf.prox(x - tau[iiter] * ATy, tau[iiter])
127141
xhat = x + theta * (x - xold)
128142
else:
129143
ATy = A.rmatvec(y)
130144
if z is not None:
131145
ATy += z
132-
x = proxf.prox(x - tau * ATy, tau)
146+
x = proxf.prox(x - tau[iiter] * ATy, tau[iiter])
133147
xhat = x + theta * (x - xold)
134-
y = proxg.proxdual(y + mu * A.matvec(xhat), mu)
148+
y = proxg.proxdual(y + mu[iiter] * A.matvec(xhat), mu[iiter])
135149

136150
# run callback
137151
if callback is not None:
138-
callback(x)
139-
152+
if callbacky:
153+
callback(x, y)
154+
else:
155+
callback(x)
140156
if show:
141157
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
142158
pf, pg = proxf(x), proxg(A.matvec(x))

pyproximal/projection/L1.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,8 @@ def __init__(self, n, radius, maxiter=100, xtol=1e-5):
4040
self.simplex = SimplexProj(n, radius, maxiter, xtol)
4141

4242
def __call__(self, x):
43-
return np.sign(x) * self.simplex(np.abs(x))
43+
if np.iscomplexobj(x):
44+
return np.exp(1j * np.angle(x)) * self.simplex(np.abs(x))
45+
else:
46+
return np.sign(x) * self.simplex(np.abs(x))
47+

0 commit comments

Comments
 (0)