@@ -33,6 +33,13 @@ def _softthreshold(x, thresh):
3333 return x1
3434
3535
36+ def _current_sigma (sigma , count ):
37+ if not callable (sigma ):
38+ return sigma
39+ else :
40+ return sigma (count )
41+
42+
3643class L1 (ProxOperator ):
3744 r"""L1 norm proximal operator.
3845
@@ -41,8 +48,12 @@ class L1(ProxOperator):
4148
4249 Parameters
4350 ----------
44- sigma : :obj:`int`, optional
45- Multiplicative coefficient of L1 norm
51+ sigma : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
52+ Multiplicative coefficient of L1 norm. This can be a constant number, a list
53+ of values (for multidimensional inputs, acting on the second dimension) or
54+ a function that is called passing a counter which keeps track of how many
55+ times the ``prox`` method has been invoked before and returns a scalar (or a list of)
56+ ``sigma`` to be used.
4657 g : :obj:`np.ndarray`, optional
4758 Vector to be subtracted
4859
@@ -85,18 +96,33 @@ def __init__(self, sigma=1., g=None):
8596 self .sigma = sigma
8697 self .g = g
8798 self .gdual = 0 if g is None else g
88- self .box = BoxProj (- sigma , sigma )
99+ if not callable (sigma ):
100+ self .box = BoxProj (- sigma , sigma )
101+ else :
102+ self .box = BoxProj (- sigma (0 ), sigma (0 ))
103+ self .count = 0
89104
90105 def __call__ (self , x ):
91- return self .sigma * np .sum (np .abs (x ))
92-
106+ sigma = _current_sigma (self .sigma , self .count )
107+ return sigma * np .sum (np .abs (x ))
108+
109+ def _increment_count (func ):
110+ """Increment counter
111+ """
112+ def wrapped (self , * args , ** kwargs ):
113+ self .count += 1
114+ return func (self , * args , ** kwargs )
115+ return wrapped
116+
117+ @_increment_count
93118 @_check_tau
94119 def prox (self , x , tau ):
120+ sigma = _current_sigma (self .sigma , self .count )
95121 if self .g is None :
96- x = _softthreshold (x , tau * self . sigma )
122+ x = _softthreshold (x , tau * sigma )
97123 else :
98124 # use precomposition property
99- x = _softthreshold (x - self .g , tau * self . sigma ) + self .g
125+ x = _softthreshold (x - self .g , tau * sigma ) + self .g
100126 return x
101127
102128 @_check_tau
0 commit comments