Skip to content

Commit de7a5e8

Browse files
authored
Merge pull request #84 from mrava87/main
Variable sigma in L1
2 parents c0e7ef9 + 0a0aa52 commit de7a5e8

File tree

3 files changed

+44
-13
lines changed

3 files changed

+44
-13
lines changed

pyproximal/optimization/palm.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,28 +77,33 @@ def PALM(H, proxf, proxg, x0, y0, gammaf=1., gammag=1.,
7777
'Proximal operator (g): %s\n'
7878
'gammaf = %10e\tgammaf = %10e\tniter = %d\n' %
7979
(type(H), type(proxf), type(proxg), gammaf, gammag, niter))
80-
head = ' Itn x[0] y[0] f g H'
80+
head = ' Itn x[0] y[0] f g H ck dk'
8181
print(head)
8282

8383
x, y = x0.copy(), y0.copy()
8484
for iiter in range(niter):
8585
ck = gammaf * H.ly(y)
8686
x = x - (1 / ck) * H.gradx(x.ravel())
87-
x = proxf.prox(x, ck)
87+
if proxf is not None:
88+
x = proxf.prox(x, ck)
8889
H.updatex(x.copy())
8990
dk = gammag * H.lx(x)
9091
y = y - (1 / dk) * H.grady(y.ravel())
91-
y = proxg.prox(y, dk)
92+
if proxg is not None:
93+
y = proxg.prox(y, dk)
9294
H.updatey(y.copy())
9395

9496
# run callback
9597
if callback is not None:
9698
callback(x, y)
9799

98100
if show:
101+
pf = proxf(x) if proxf is not None else 0.
102+
pg = proxg(y) if proxg is not None else 0.
99103
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
100-
msg = '%6g %5.5e %5.2e %5.2e %5.2e %5.2e' % \
101-
(iiter + 1, x[0], y[0], proxf(x), proxg(y), H(x, y))
104+
msg = '%6g %5.5e %5.2e %5.2e %5.2e %5.2e %5.2e %5.2e' % \
105+
(iiter + 1, x[0], y[0], pf if pf is not None else 0.,
106+
pg if pg is not None else 0., H(x, y), ck, dk)
102107
print(msg)
103108
if show:
104109
print('\nTotal time (s) = %.2f' % (time.time() - tstart))

pyproximal/proximal/L1.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ def _softthreshold(x, thresh):
3333
return x1
3434

3535

36+
def _current_sigma(sigma, count):
37+
if not callable(sigma):
38+
return sigma
39+
else:
40+
return sigma(count)
41+
42+
3643
class L1(ProxOperator):
3744
r"""L1 norm proximal operator.
3845
@@ -41,8 +48,12 @@ class L1(ProxOperator):
4148
4249
Parameters
4350
----------
44-
sigma : :obj:`int`, optional
45-
Multiplicative coefficient of L1 norm
51+
sigma : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
52+
Multiplicative coefficient of L1 norm. This can be a constant number, a list
53+
of values (for multidimensional inputs, acting on the second dimension) or
54+
a function that is called passing a counter which keeps track of how many
55+
times the ``prox`` method has been invoked before and returns a scalar (or a list of)
56+
``sigma`` to be used.
4657
g : :obj:`np.ndarray`, optional
4758
Vector to be subtracted
4859
@@ -85,18 +96,33 @@ def __init__(self, sigma=1., g=None):
8596
self.sigma = sigma
8697
self.g = g
8798
self.gdual = 0 if g is None else g
88-
self.box = BoxProj(-sigma, sigma)
99+
if not callable(sigma):
100+
self.box = BoxProj(-sigma, sigma)
101+
else:
102+
self.box = BoxProj(-sigma(0), sigma(0))
103+
self.count = 0
89104

90105
def __call__(self, x):
91-
return self.sigma * np.sum(np.abs(x))
92-
106+
sigma = _current_sigma(self.sigma, self.count)
107+
return sigma * np.sum(np.abs(x))
108+
109+
def _increment_count(func):
110+
"""Increment counter
111+
"""
112+
def wrapped(self, *args, **kwargs):
113+
self.count += 1
114+
return func(self, *args, **kwargs)
115+
return wrapped
116+
117+
@_increment_count
93118
@_check_tau
94119
def prox(self, x, tau):
120+
sigma = _current_sigma(self.sigma, self.count)
95121
if self.g is None:
96-
x = _softthreshold(x, tau * self.sigma)
122+
x = _softthreshold(x, tau * sigma)
97123
else:
98124
# use precomposition property
99-
x = _softthreshold(x - self.g, tau * self.sigma) + self.g
125+
x = _softthreshold(x - self.g, tau * sigma) + self.g
100126
return x
101127

102128
@_check_tau

pyproximal/proximal/L2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, Op=None, b=None, q=None, sigma=1., alpha=1.,
9797
self.count = 0
9898

9999
# create data term
100-
if self.Op is not None:
100+
if self.Op is not None and self.b is not None:
101101
self.OpTb = self.sigma * self.Op.H @ self.b
102102
# create A.T A upfront for explicit operators
103103
if self.Op.explicit:

0 commit comments

Comments
 (0)