Skip to content

Commit 0835b1a

Browse files
committed
feature: added variable sigma in L1
1 parent 3a83f3c commit 0835b1a

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

pyproximal/proximal/L1.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ def _softthreshold(x, thresh):
3333
return x1
3434

3535

36+
def _current_sigma(sigma, count):
37+
if isinstance(sigma, (int, float)):
38+
return sigma
39+
else:
40+
return sigma(count)
41+
42+
3643
class L1(ProxOperator):
3744
r"""L1 norm proximal operator.
3845
@@ -42,7 +49,10 @@ class L1(ProxOperator):
4249
Parameters
4350
----------
4451
sigma : :obj:`int`, optional
45-
Multiplicative coefficient of L1 norm
52+
Multiplicative coefficient of L1 norm. This can be a constant number or
53+
a function that is called passing a counter which keeps track of how many
54+
times the ``prox`` method has been invoked before and
55+
returns the ``sigma`` to be used.
4656
g : :obj:`np.ndarray`, optional
4757
Vector to be subtracted
4858
@@ -85,18 +95,33 @@ def __init__(self, sigma=1., g=None):
8595
self.sigma = sigma
8696
self.g = g
8797
self.gdual = 0 if g is None else g
88-
self.box = BoxProj(-sigma, sigma)
98+
if isinstance(sigma, (int, float)):
99+
self.box = BoxProj(-sigma, sigma)
100+
else:
101+
self.box = BoxProj(-sigma(0), sigma(0))
102+
self.count = 0
89103

90104
def __call__(self, x):
91-
return self.sigma * np.sum(np.abs(x))
92-
105+
sigma = _current_sigma(self.sigma, self.count)
106+
return sigma * np.sum(np.abs(x))
107+
108+
def _increment_count(func):
109+
"""Increment counter
110+
"""
111+
def wrapped(self, *args, **kwargs):
112+
self.count += 1
113+
return func(self, *args, **kwargs)
114+
return wrapped
115+
116+
@_increment_count
93117
@_check_tau
94118
def prox(self, x, tau):
119+
sigma = _current_sigma(self.sigma, self.count)
95120
if self.g is None:
96-
x = _softthreshold(x, tau * self.sigma)
121+
x = _softthreshold(x, tau * sigma)
97122
else:
98123
# use precomposition property
99-
x = _softthreshold(x - self.g, tau * self.sigma) + self.g
124+
x = _softthreshold(x - self.g, tau * sigma) + self.g
100125
return x
101126

102127
@_check_tau

0 commit comments

Comments
 (0)