@@ -33,6 +33,13 @@ def _softthreshold(x, thresh):
3333 return x1
3434
3535
36+ def _current_sigma (sigma , count ):
37+ if isinstance (sigma , (int , float )):
38+ return sigma
39+ else :
40+ return sigma (count )
41+
42+
3643class L1 (ProxOperator ):
3744 r"""L1 norm proximal operator.
3845
@@ -42,7 +49,10 @@ class L1(ProxOperator):
4249 Parameters
4350 ----------
4451 sigma : :obj:`int`, optional
45- Multiplicative coefficient of L1 norm
52+ Multiplicative coefficient of L1 norm. This can be a constant number or
53+ a function that is called passing a counter which keeps track of how many
54+ times the ``prox`` method has been invoked before and
55+ returns the ``sigma`` to be used.
4656 g : :obj:`np.ndarray`, optional
4757 Vector to be subtracted
4858
@@ -85,18 +95,33 @@ def __init__(self, sigma=1., g=None):
8595 self .sigma = sigma
8696 self .g = g
8797 self .gdual = 0 if g is None else g
88- self .box = BoxProj (- sigma , sigma )
98+ if isinstance (sigma , (int , float )):
99+ self .box = BoxProj (- sigma , sigma )
100+ else :
101+ self .box = BoxProj (- sigma (0 ), sigma (0 ))
102+ self .count = 0
89103
90104 def __call__ (self , x ):
91- return self .sigma * np .sum (np .abs (x ))
92-
105+ sigma = _current_sigma (self .sigma , self .count )
106+ return sigma * np .sum (np .abs (x ))
107+
108+ def _increment_count (func ):
109+ """Increment counter
110+ """
111+ def wrapped (self , * args , ** kwargs ):
112+ self .count += 1
113+ return func (self , * args , ** kwargs )
114+ return wrapped
115+
116+ @_increment_count
93117 @_check_tau
94118 def prox (self , x , tau ):
119+ sigma = _current_sigma (self .sigma , self .count )
95120 if self .g is None :
96- x = _softthreshold (x , tau * self . sigma )
121+ x = _softthreshold (x , tau * sigma )
97122 else :
98123 # use precomposition property
99- x = _softthreshold (x - self .g , tau * self . sigma ) + self .g
124+ x = _softthreshold (x - self .g , tau * sigma ) + self .g
100125 return x
101126
102127 @_check_tau
0 commit comments