Skip to content

Commit 45e2f22

Browse files
authored
Merge pull request #100 from mrava87/main
Added variable radius in L0Ball
2 parents 3be1112 + 894c894 commit 45e2f22

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

pyproximal/proximal/L0.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ class L0Ball(ProxOperator):
9999
100100
Parameters
101101
----------
102-
radius : :obj:`float`
102+
radius : :obj:`int` or :obj:`func`, optional
103+
Radius. This can be a constant number or a function that is called passing a
104+
counter which keeps track of how many times the ``prox`` method has been
105+
invoked before and returns a scalar ``radius`` to be used.
103106
Radius
104107
105108
Notes
@@ -112,11 +115,25 @@ class L0Ball(ProxOperator):
112115
def __init__(self, radius):
113116
super().__init__(None, False)
114117
self.radius = radius
115-
self.ball = L0BallProj(self.radius)
118+
self.ball = L0BallProj(self.radius if not callable(radius) else radius(0))
119+
self.count = 0
116120

117121
def __call__(self, x, tol=1e-4):
118-
return np.linalg.norm(np.abs(x), ord=0) <= self.radius
122+
radius = _current_sigma(self.radius, self.count)
123+
return np.linalg.norm(np.abs(x), ord=0) <= radius
124+
125+
def _increment_count(func):
126+
"""Increment counter
127+
"""
128+
def wrapped(self, *args, **kwargs):
129+
self.count += 1
130+
return func(self, *args, **kwargs)
131+
return wrapped
119132

133+
@_increment_count
120134
@_check_tau
121135
def prox(self, x, tau):
122-
return self.ball(x)
136+
radius = _current_sigma(self.radius, self.count)
137+
self.ball.radius = radius
138+
y = self.ball(x)
139+
return y

0 commit comments

Comments
 (0)