Skip to content

Commit 6941d00

Browse files
committed
minor: added option to pass aux variable to callbacks
1 parent 2feec64 commit 6941d00

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

pyproximal/optimization/primal.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,8 @@ def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
372372
return x
373373

374374

375-
def HQS(proxf, proxg, x0, tau, niter=10, gfirst=True, callback=None, show=False):
375+
def HQS(proxf, proxg, x0, tau, niter=10, gfirst=True,
376+
callback=None, callbackz=False, show=False):
376377
r"""Half Quadratic splitting
377378
378379
Solves the following minimization problem using Half Quadratic splitting
@@ -396,7 +397,9 @@ def HQS(proxf, proxg, x0, tau, niter=10, gfirst=True, callback=None, show=False)
396397
x0 : :obj:`numpy.ndarray`
397398
Initial vector
398399
tau : :obj:`float`, optional
399-
CHECK!!!
400+
Positive scalar weight, which should satisfy the following condition
401+
to guarantees convergence: :math:`\tau \in (0, 1/L]` where ``L`` is
402+
the Lipschitz constant of :math:`\nabla f`.
400403
niter : :obj:`int`, optional
401404
Number of iterations of iterative scheme
402405
gfirst : :obj:`bool`, optional
@@ -405,6 +408,8 @@ def HQS(proxf, proxg, x0, tau, niter=10, gfirst=True, callback=None, show=False)
405408
callback : :obj:`callable`, optional
406409
Function with signature (``callback(x)``) to call after each iteration
407410
where ``x`` is the current model vector
411+
callbackz : :obj:`bool`, optional
412+
Modify callback signature to (``callback(x, z)``) when ``callbackz=True``
408413
show : :obj:`bool`, optional
409414
Display iterations log
410415
@@ -457,7 +462,10 @@ def HQS(proxf, proxg, x0, tau, niter=10, gfirst=True, callback=None, show=False)
457462

458463
# run callback
459464
if callback is not None:
460-
callback(x)
465+
if callbackz:
466+
callback(x, z)
467+
else:
468+
callback(x)
461469

462470
if show:
463471
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
@@ -471,7 +479,8 @@ def HQS(proxf, proxg, x0, tau, niter=10, gfirst=True, callback=None, show=False)
471479
return x, z
472480

473481

474-
def ADMM(proxf, proxg, x0, tau, niter=10, gfirst=False, callback=None, show=False):
482+
def ADMM(proxf, proxg, x0, tau, niter=10, gfirst=False,
483+
callback=None, callbackz=False, show=False):
475484
r"""Alternating Direction Method of Multipliers
476485
477486
Solves the following minimization problem using Alternating Direction
@@ -522,6 +531,8 @@ def ADMM(proxf, proxg, x0, tau, niter=10, gfirst=False, callback=None, show=Fals
522531
callback : :obj:`callable`, optional
523532
Function with signature (``callback(x)``) to call after each iteration
524533
where ``x`` is the current model vector
534+
callbackz : :obj:`bool`, optional
535+
Modify callback signature to (``callback(x, z)``) when ``callbackz=True``
525536
show : :obj:`bool`, optional
526537
Display iterations log
527538
@@ -577,8 +588,10 @@ def ADMM(proxf, proxg, x0, tau, niter=10, gfirst=False, callback=None, show=Fals
577588

578589
# run callback
579590
if callback is not None:
580-
callback(x)
581-
591+
if callbackz:
592+
callback(x, z)
593+
else:
594+
callback(x)
582595
if show:
583596
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
584597
pf, pg = proxf(x), proxg(x)

pyproximal/optimization/primaldual.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
6-
gfirst=True, callback=None, show=False):
6+
gfirst=True, callback=None, callbacky=False, show=False):
77
r"""Primal-dual algorithm
88
99
Solves the following (possibly) nonlinear minimization problem using
@@ -39,10 +39,12 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
3939
Linear operator of g
4040
x0 : :obj:`numpy.ndarray`
4141
Initial vector
42-
tau : :obj:`float`
43-
Stepsize of subgradient of :math:`f`
44-
mu : :obj:`float`
45-
Stepsize of subgradient of :math:`g^*`
42+
tau : :obj:`float` or :obj:`np.ndarray`
43+
Stepsize of subgradient of :math:`f`. This can be constant
44+
or function of iterations (in the latter cases provided as np.ndarray)
45+
mu : :obj:`float` or :obj:`np.ndarray`
46+
Stepsize of subgradient of :math:`g^*`. This can be constant
47+
or function of iterations (in the latter cases provided as np.ndarray)
4648
z : :obj:`numpy.ndarray`, optional
4749
Additional vector
4850
theta : :obj:`float`
@@ -58,6 +60,8 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
5860
callback : :obj:`callable`, optional
5961
Function with signature (``callback(x)``) to call after each iteration
6062
where ``x`` is the current model vector
63+
callbacky : :obj:`bool`, optional
64+
Modify callback signature to (``callback(x, y)``) when ``callbacky=True``
6165
show : :obj:`bool`, optional
6266
Display iterations log
6367
@@ -98,6 +102,15 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
98102
Imaging and Vision, 40, 8pp. 120-145. 2011.
99103
100104
"""
105+
# check if tau and mu are scalars or arrays
106+
fixedtau = fixedmu = False
107+
if isinstance(tau, (int, float)):
108+
tau = tau * np.ones(niter)
109+
fixedtau = True
110+
if isinstance(mu, (int, float)):
111+
mu = mu * np.ones(niter)
112+
fixedmu = True
113+
101114
if show:
102115
tstart = time.time()
103116
print('Primal-dual: min_x f(Ax) + x^T z + g(x)\n'
@@ -106,9 +119,10 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
106119
'Proximal operator (g): %s\n'
107120
'Linear operator (A): %s\n'
108121
'Additional vector (z): %s\n'
109-
'tau = %10e\tmu = %10e\ntheta = %.2f\t\tniter = %d\n' %
122+
'tau = %s\t\tmu = %s\ntheta = %.2f\t\tniter = %d\n' %
110123
(type(proxf), type(proxg), type(A),
111-
None if z is None else 'vector', tau, mu, theta, niter))
124+
None if z is None else 'vector', str(tau[0]) if fixedtau else 'Variable',
125+
str(mu[0]) if fixedmu else 'Variable', theta, niter))
112126
head = ' Itn x[0] f g z^x J = f + g + z^x'
113127
print(head)
114128

@@ -119,24 +133,26 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
119133
for iiter in range(niter):
120134
xold = x.copy()
121135
if gfirst:
122-
y = proxg.proxdual(y + mu * A.matvec(xhat), mu)
136+
y = proxg.proxdual(y + mu[iiter] * A.matvec(xhat), mu[iiter])
123137
ATy = A.rmatvec(y)
124138
if z is not None:
125139
ATy += z
126-
x = proxf.prox(x - tau * ATy, tau)
140+
x = proxf.prox(x - tau[iiter] * ATy, tau[iiter])
127141
xhat = x + theta * (x - xold)
128142
else:
129143
ATy = A.rmatvec(y)
130144
if z is not None:
131145
ATy += z
132-
x = proxf.prox(x - tau * ATy, tau)
146+
x = proxf.prox(x - tau[iiter] * ATy, tau[iiter])
133147
xhat = x + theta * (x - xold)
134-
y = proxg.proxdual(y + mu * A.matvec(xhat), mu)
148+
y = proxg.proxdual(y + mu[iiter] * A.matvec(xhat), mu[iiter])
135149

136150
# run callback
137151
if callback is not None:
138-
callback(x)
139-
152+
if callbacky:
153+
callback(x, y)
154+
else:
155+
callback(x)
140156
if show:
141157
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
142158
pf, pg = proxf(x), proxg(A.matvec(x))

0 commit comments

Comments
 (0)