@@ -99,7 +99,10 @@ class L0Ball(ProxOperator):
9999
100100 Parameters
101101 ----------
102- radius : :obj:`float`
102+ radius : :obj:`int` or :obj:`func`, optional
103+ Radius. This can be a constant number or a function that is called passing a
104+ counter which keeps track of how many times the ``prox`` method has been
105+ invoked before and returns a scalar ``radius`` to be used.
103106 Radius
104107
105108 Notes
@@ -112,11 +115,25 @@ class L0Ball(ProxOperator):
112115 def __init__ (self , radius ):
113116 super ().__init__ (None , False )
114117 self .radius = radius
115- self .ball = L0BallProj (self .radius )
118+ self .ball = L0BallProj (self .radius if not callable (radius ) else radius (0 ))
119+ self .count = 0
116120
117121 def __call__ (self , x , tol = 1e-4 ):
118- return np .linalg .norm (np .abs (x ), ord = 0 ) <= self .radius
122+ radius = _current_sigma (self .radius , self .count )
123+ return np .linalg .norm (np .abs (x ), ord = 0 ) <= radius
124+
125+ def _increment_count (func ):
126+ """Increment counter
127+ """
128+ def wrapped (self , * args , ** kwargs ):
129+ self .count += 1
130+ return func (self , * args , ** kwargs )
131+ return wrapped
119132
133+ @_increment_count
120134 @_check_tau
121135 def prox (self , x , tau ):
122- return self .ball (x )
136+ radius = _current_sigma (self .radius , self .count )
137+ self .ball .radius = radius
138+ y = self .ball (x )
139+ return y
0 commit comments