Skip to content

Commit a7e89f1

Browse files
Merge branch 'main' of https://github.com/reivilo3/pyproximal into main
2 parents 80e40f2 + 9c709f9 commit a7e89f1

File tree

8 files changed

+260
-31
lines changed

8 files changed

+260
-31
lines changed

docs/source/api/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ Convex
6464
EuclideanBall
6565
Huber
6666
Intersection
67+
L0
6768
L0Ball
6869
L1
6970
L1Ball
@@ -137,6 +138,7 @@ Primal
137138
AcceleratedProximalGradient
138139
ADMM
139140
ADMML2
141+
HQS
140142
LinearizedADMM
141143
ProximalGradient
142144
GeneralizedProximalGradient

pyproximal/optimization/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
ProximalPoint Proximal point algorithm (or proximal min.)
1212
ProximalGradient Proximal gradient algorithm
1313
AcceleratedProximalGradient Accelerated Proximal gradient algorithm
14+
HQS Half Quadrating Splitting
1415
ADMM Alternating Direction Method of Multipliers
15-
ADMML2 ADMM with L2 misfit term
16+
ADMML2 ADMM with L2 misfit term
1617
LinearizedADMM Linearized ADMM
1718
TwIST Two-step Iterative Shrinkage/Threshold
1819
PlugAndPlay Plug-and-Play Prior with ADMM

pyproximal/optimization/primal.py

Lines changed: 125 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,115 @@ def GeneralizedProximalGradient(proxfs, proxgs, x0, tau=None, beta=0.5,
406406
return x
407407

408408

409-
def ADMM(proxf, proxg, x0, tau, niter=10, callback=None, show=False):
409+
def HQS(proxf, proxg, x0, tau, niter=10, gfirst=True,
410+
callback=None, callbackz=False, show=False):
411+
r"""Half Quadratic splitting
412+
413+
Solves the following minimization problem using Half Quadratic splitting
414+
algorithm:
415+
416+
.. math::
417+
418+
\mathbf{x},\mathbf{z} = \argmin_{\mathbf{x},\mathbf{z}}
419+
f(\mathbf{x}) + g(\mathbf{z}) \\
420+
s.t. \; \mathbf{x}=\mathbf{z}
421+
422+
where :math:`f(\mathbf{x})` and :math:`g(\mathbf{z})` are any convex
423+
function that has a known proximal operator.
424+
425+
Parameters
426+
----------
427+
proxf : :obj:`pyproximal.ProxOperator`
428+
Proximal operator of f function
429+
proxg : :obj:`pyproximal.ProxOperator`
430+
Proximal operator of g function
431+
x0 : :obj:`numpy.ndarray`
432+
Initial vector
433+
tau : :obj:`float`, optional
434+
Positive scalar weight, which should satisfy the following condition
435+
to guarantees convergence: :math:`\tau \in (0, 1/L]` where ``L`` is
436+
the Lipschitz constant of :math:`\nabla f`.
437+
niter : :obj:`int`, optional
438+
Number of iterations of iterative scheme
439+
gfirst : :obj:`bool`, optional
440+
Apply Proximal of operator ``g`` first (``True``) or Proximal of
441+
operator ``f`` first (``False``)
442+
callback : :obj:`callable`, optional
443+
Function with signature (``callback(x)``) to call after each iteration
444+
where ``x`` is the current model vector
445+
callbackz : :obj:`bool`, optional
446+
Modify callback signature to (``callback(x, z)``) when ``callbackz=True``
447+
show : :obj:`bool`, optional
448+
Display iterations log
449+
450+
Returns
451+
-------
452+
x : :obj:`numpy.ndarray`
453+
Inverted model
454+
z : :obj:`numpy.ndarray`
455+
Inverted second model
456+
457+
Notes
458+
-----
459+
The HQS algorithm can be expressed by the following recursion [1]_:
460+
461+
.. math::
462+
463+
\mathbf{z}^{k+1} = \prox_{\tau g}(\mathbf{x}^{k})
464+
\mathbf{x}^{k+1} = \prox_{\tau f}(\mathbf{z}^{k+1})\\
465+
466+
Note that ``x`` and ``z`` converge to each other, however if iterations are
467+
stopped too early ``x`` is guaranteed to belong to the domain of ``f``
468+
while ``z`` is guaranteed to belong to the domain of ``g``. Depending on
469+
the problem either of the two may be the best solution.
470+
471+
.. [1] D., Geman, and C., Yang, "Nonlinear image recovery with halfquadratic
472+
regularization", IEEE Transactions on Image Processing,
473+
4, 7, pp. 932-946, 1995.
474+
475+
"""
476+
if show:
477+
tstart = time.time()
478+
print('HQS\n'
479+
'---------------------------------------------------------\n'
480+
'Proximal operator (f): %s\n'
481+
'Proximal operator (g): %s\n'
482+
'tau = %10e\tniter = %d\n' % (type(proxf), type(proxg),
483+
tau, niter))
484+
head = ' Itn x[0] f g J = f + g'
485+
print(head)
486+
487+
x = x0.copy()
488+
z = np.zeros_like(x)
489+
for iiter in range(niter):
490+
if gfirst:
491+
z = proxg.prox(x, tau)
492+
x = proxf.prox(z, tau)
493+
else:
494+
x = proxf.prox(z, tau)
495+
z = proxg.prox(x, tau)
496+
497+
# run callback
498+
if callback is not None:
499+
if callbackz:
500+
callback(x, z)
501+
else:
502+
callback(x)
503+
504+
if show:
505+
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
506+
pf, pg = proxf(x), proxg(x)
507+
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
508+
(iiter + 1, x[0], pf, pg, pf + pg)
509+
print(msg)
510+
if show:
511+
print('\nTotal time (s) = %.2f' % (time.time() - tstart))
512+
print('---------------------------------------------------------\n')
513+
return x, z
514+
515+
516+
def ADMM(proxf, proxg, x0, tau, niter=10, gfirst=False,
517+
callback=None, callbackz=False, show=False):
410518
r"""Alternating Direction Method of Multipliers
411519
412520
Solves the following minimization problem using Alternating Direction
@@ -451,9 +559,14 @@ def ADMM(proxf, proxg, x0, tau, niter=10, callback=None, show=False):
451559
the Lipschitz constant of :math:`\nabla f`.
452560
niter : :obj:`int`, optional
453561
Number of iterations of iterative scheme
562+
gfirst : :obj:`bool`, optional
563+
Apply Proximal of operator ``g`` first (``True``) or Proximal of
564+
operator ``f`` first (``False``)
454565
callback : :obj:`callable`, optional
455566
Function with signature (``callback(x)``) to call after each iteration
456567
where ``x`` is the current model vector
568+
callbackz : :obj:`bool`, optional
569+
Modify callback signature to (``callback(x, z)``) when ``callbackz=True``
457570
show : :obj:`bool`, optional
458571
Display iterations log
459572
@@ -479,7 +592,7 @@ def ADMM(proxf, proxg, x0, tau, niter=10, callback=None, show=False):
479592
\mathbf{z}^{k+1} = \prox_{\tau g}(\mathbf{x}^{k+1} + \mathbf{u}^{k})\\
480593
\mathbf{u}^{k+1} = \mathbf{u}^{k} + \mathbf{x}^{k+1} - \mathbf{z}^{k+1}
481594
482-
Note that ``x`` and ``z`` converge to each other, but if iterations are
595+
Note that ``x`` and ``z`` converge to each other, however if iterations are
483596
stopped too early ``x`` is guaranteed to belong to the domain of ``f``
484597
while ``z`` is guaranteed to belong to the domain of ``g``. Depending on
485598
the problem either of the two may be the best solution.
@@ -499,14 +612,20 @@ def ADMM(proxf, proxg, x0, tau, niter=10, callback=None, show=False):
499612
x = x0.copy()
500613
u = z = np.zeros_like(x)
501614
for iiter in range(niter):
502-
x = proxf.prox(z - u, tau)
503-
z = proxg.prox(x + u, tau)
615+
if gfirst:
616+
z = proxg.prox(x + u, tau)
617+
x = proxf.prox(z - u, tau)
618+
else:
619+
x = proxf.prox(z - u, tau)
620+
z = proxg.prox(x + u, tau)
504621
u = u + x - z
505622

506623
# run callback
507624
if callback is not None:
508-
callback(x)
509-
625+
if callbackz:
626+
callback(x, z)
627+
else:
628+
callback(x)
510629
if show:
511630
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
512631
pf, pg = proxf(x), proxg(x)

pyproximal/optimization/primaldual.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
6-
gfirst=True, callback=None, show=False):
6+
gfirst=True, callback=None, callbacky=False, show=False):
77
r"""Primal-dual algorithm
88
99
Solves the following (possibly) nonlinear minimization problem using
@@ -39,10 +39,12 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
3939
Linear operator of g
4040
x0 : :obj:`numpy.ndarray`
4141
Initial vector
42-
tau : :obj:`float`
43-
Stepsize of subgradient of :math:`f`
44-
mu : :obj:`float`
45-
Stepsize of subgradient of :math:`g^*`
42+
tau : :obj:`float` or :obj:`np.ndarray`
43+
Stepsize of subgradient of :math:`f`. This can be constant
44+
or function of iterations (in the latter cases provided as np.ndarray)
45+
mu : :obj:`float` or :obj:`np.ndarray`
46+
Stepsize of subgradient of :math:`g^*`. This can be constant
47+
or function of iterations (in the latter cases provided as np.ndarray)
4648
z : :obj:`numpy.ndarray`, optional
4749
Additional vector
4850
theta : :obj:`float`
@@ -58,6 +60,8 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
5860
callback : :obj:`callable`, optional
5961
Function with signature (``callback(x)``) to call after each iteration
6062
where ``x`` is the current model vector
63+
callbacky : :obj:`bool`, optional
64+
Modify callback signature to (``callback(x, y)``) when ``callbacky=True``
6165
show : :obj:`bool`, optional
6266
Display iterations log
6367
@@ -93,11 +97,20 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
9397
\mathbf{y}^{k+1} = \prox_{\mu g^*}(\mathbf{y}^{k} +
9498
\mu \mathbf{A}\bar{\mathbf{x}}^{k+1})
9599
96-
.. [1] A., Chambolle, and T., Pock, "A first-order primal-dual algorithm for
100+
.. [1] A., Chambolle, and T., Pock, "A first-order primal-dual algorithm for
97101
convex problems with applications to imaging", Journal of Mathematical
98-
Imaging and Vision, 40, 8pp. 120145. 2011.
102+
Imaging and Vision, 40, 8pp. 120-145. 2011.
99103
100104
"""
105+
# check if tau and mu are scalars or arrays
106+
fixedtau = fixedmu = False
107+
if isinstance(tau, (int, float)):
108+
tau = tau * np.ones(niter)
109+
fixedtau = True
110+
if isinstance(mu, (int, float)):
111+
mu = mu * np.ones(niter)
112+
fixedmu = True
113+
101114
if show:
102115
tstart = time.time()
103116
print('Primal-dual: min_x f(Ax) + x^T z + g(x)\n'
@@ -106,9 +119,10 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
106119
'Proximal operator (g): %s\n'
107120
'Linear operator (A): %s\n'
108121
'Additional vector (z): %s\n'
109-
'tau = %10e\tmu = %10e\ntheta = %.2f\t\tniter = %d\n' %
122+
'tau = %s\t\tmu = %s\ntheta = %.2f\t\tniter = %d\n' %
110123
(type(proxf), type(proxg), type(A),
111-
None if z is None else 'vector', tau, mu, theta, niter))
124+
None if z is None else 'vector', str(tau[0]) if fixedtau else 'Variable',
125+
str(mu[0]) if fixedmu else 'Variable', theta, niter))
112126
head = ' Itn x[0] f g z^x J = f + g + z^x'
113127
print(head)
114128

@@ -119,24 +133,26 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
119133
for iiter in range(niter):
120134
xold = x.copy()
121135
if gfirst:
122-
y = proxg.proxdual(y + mu * A.matvec(xhat), mu)
136+
y = proxg.proxdual(y + mu[iiter] * A.matvec(xhat), mu[iiter])
123137
ATy = A.rmatvec(y)
124138
if z is not None:
125139
ATy += z
126-
x = proxf.prox(x - tau * ATy, tau)
140+
x = proxf.prox(x - tau[iiter] * ATy, tau[iiter])
127141
xhat = x + theta * (x - xold)
128142
else:
129143
ATy = A.rmatvec(y)
130144
if z is not None:
131145
ATy += z
132-
x = proxf.prox(x - tau * ATy, tau)
146+
x = proxf.prox(x - tau[iiter] * ATy, tau[iiter])
133147
xhat = x + theta * (x - xold)
134-
y = proxg.proxdual(y + mu * A.matvec(xhat), mu)
148+
y = proxg.proxdual(y + mu[iiter] * A.matvec(xhat), mu[iiter])
135149

136150
# run callback
137151
if callback is not None:
138-
callback(x)
139-
152+
if callbacky:
153+
callback(x, y)
154+
else:
155+
callback(x)
140156
if show:
141157
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
142158
pf, pg = proxf(x), proxg(A.matvec(x))

pyproximal/projection/L1.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,8 @@ def __init__(self, n, radius, maxiter=100, xtol=1e-5):
4040
self.simplex = SimplexProj(n, radius, maxiter, xtol)
4141

4242
def __call__(self, x):
43-
return np.sign(x) * self.simplex(np.abs(x))
43+
if np.iscomplexobj(x):
44+
return np.exp(1j * np.angle(x)) * self.simplex(np.abs(x))
45+
else:
46+
return np.sign(x) * self.simplex(np.abs(x))
47+

0 commit comments

Comments
 (0)