Skip to content

Commit b19399c

Browse files
authored
Feat: l01ball (#154)
* feat: added L01Ball operator
1 parent 330bcea commit b19399c

File tree

8 files changed

+286
-12
lines changed

8 files changed

+286
-12
lines changed

docs/source/api/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Orthogonal projections
2424
HyperPlaneBoxProj
2525
IntersectionProj
2626
L0BallProj
27+
L01BallProj
2728
L1BallProj
2829
NuclearBallProj
2930
SimplexProj
@@ -68,6 +69,7 @@ Convex
6869
Intersection
6970
L0
7071
L0Ball
72+
L01Ball
7173
L1
7274
L1Ball
7375
L2

pyproximal/projection/L0.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import numpy as np
2-
from pyproximal.projection import SimplexProj
32

43

54
class L0BallProj():
6-
r"""L0 ball projection.
5+
r""":math:`L_0` ball projection.
76
87
Parameters
98
----------
@@ -32,4 +31,40 @@ def __call__(self, x):
3231
xshape = x.shape
3332
xf = x.copy().flatten()
3433
xf[np.argsort(np.abs(xf))[:-self.radius]] = 0
35-
return xf.reshape(xshape)
34+
return xf.reshape(xshape)
35+
36+
37+
class L01BallProj():
38+
r""":math:`L_{0,1}` ball projection.
39+
40+
Parameters
41+
----------
42+
radius : :obj:`int`
43+
Radius
44+
45+
Notes
46+
-----
47+
Given an :math:`L_{0,1}` ball defined as:
48+
49+
.. math::
50+
51+
L_{0,1}^{r} =
52+
\{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1,
53+
||\mathbf{x}_2||_1, ..., ||\mathbf{x}_1||_1] \ne 0) \leq r \}
54+
55+
its orthogonal projection is computed by finding the :math:`r` highest
56+
largest entries of a vector obtained by applying the :math:`L_1` norm to each
57+
column of a matrix :math:`\mathbf{x}` (in absolute value), keeping those
58+
and zero-ing all the other entries.
59+
Note that this is the proximal operator of the corresponding
60+
indicator function :math:`\mathcal{I}_{L_{0,1}^{r}}`.
61+
62+
"""
63+
def __init__(self, radius):
64+
self.radius = int(radius)
65+
66+
def __call__(self, x):
67+
xc = x.copy()
68+
xf = np.linalg.norm(x, axis=0, ord=1)
69+
xc[:, np.argsort(np.abs(xf))[:-self.radius]] = 0
70+
return xc

pyproximal/projection/L1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
class L1BallProj():
6-
r"""L1 ball projection.
6+
r""":math:`L_1` ball projection.
77
88
Parameters
99
----------

pyproximal/projection/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
HyperPlaneBoxProj Projection onto an intersection beween a HyperPlane and a Box
99
SimplexProj Projection onto a Simplex
1010
L0Proj Projection onto an L0 Ball
11+
L01Proj Projection onto an L0,1 Ball
1112
L1Proj Projection onto an L1 Ball
1213
EuclideanBallProj Projection onto an Euclidean Ball
1314
NuclearBallProj Projection onto a Nuclear Ball
@@ -29,5 +30,5 @@
2930

3031

3132
__all__ = ['BoxProj', 'HyperPlaneBoxProj', 'SimplexProj', 'L0BallProj',
32-
'L1BallProj', 'EuclideanBallProj', 'NuclearBallProj',
33+
'L01BallProj', 'L1BallProj', 'EuclideanBallProj', 'NuclearBallProj',
3334
'IntersectionProj', 'AffineSetProj', 'HankelProj']

pyproximal/proximal/L0.py

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

33
from pyproximal.ProxOperator import _check_tau
4-
from pyproximal.projection import L0BallProj
4+
from pyproximal.projection import L0BallProj, L01BallProj
55
from pyproximal import ProxOperator
66
from pyproximal.proximal.L1 import _current_sigma
77

@@ -35,7 +35,7 @@ def _hardthreshold(x, thresh):
3535

3636

3737
class L0(ProxOperator):
38-
r"""L0 norm proximal operator.
38+
r""":math:`L_0` norm proximal operator.
3939
4040
Proximal operator of the :math:`\ell_0` norm:
4141
:math:`\sigma\|\mathbf{x}\|_0 = \text{count}(x_i \ne 0)`.
@@ -92,7 +92,7 @@ def prox(self, x, tau):
9292

9393

9494
class L0Ball(ProxOperator):
95-
r"""L0 ball proximal operator.
95+
r""":math:`L_0` ball proximal operator.
9696
9797
Proximal operator of the L0 ball: :math:`L0_{r} =
9898
\{ \mathbf{x}: ||\mathbf{x}||_0 \leq r \}`.
@@ -103,7 +103,6 @@ class L0Ball(ProxOperator):
103103
Radius. This can be a constant number or a function that is called passing a
104104
counter which keeps track of how many times the ``prox`` method has been
105105
invoked before and returns a scalar ``radius`` to be used.
106-
Radius
107106
108107
Notes
109108
-----
@@ -136,4 +135,61 @@ def prox(self, x, tau):
136135
radius = _current_sigma(self.radius, self.count)
137136
self.ball.radius = radius
138137
y = self.ball(x)
139-
return y
138+
return y
139+
140+
141+
class L01Ball(ProxOperator):
142+
r""":math:`L_{0,1}` ball proximal operator.
143+
144+
Proximal operator of the :math:`L_{0,1}` ball: :math:`L_{0,1}^{r} =
145+
\{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1, ||\mathbf{x}_2||_1, ...,
146+
||\mathbf{x}_1||_1] \ne 0) \leq r \}`
147+
148+
Parameters
149+
----------
150+
ndim : :obj:`int`
151+
Number of dimensions :math:`N_{dim}`. Used to reshape the input array
152+
in a matrix of size :math:`N_{dim} \times N'_{x}` where
153+
:math:`N'_x = \frac{N_x}{N_{dim}}`. Note that the input
154+
vector ``x`` should be created by stacking vectors from different
155+
dimensions.
156+
radius : :obj:`int` or :obj:`func`, optional
157+
Radius. This can be a constant number or a function that is called passing a
158+
counter which keeps track of how many times the ``prox`` method has been
159+
invoked before and returns a scalar ``radius`` to be used.
160+
161+
Notes
162+
-----
163+
As the L0 ball is an indicator function, the proximal operator
164+
corresponds to its orthogonal projection
165+
(see :class:`pyproximal.projection.L01BallProj` for details.
166+
167+
"""
168+
def __init__(self, ndim, radius):
169+
super().__init__(None, False)
170+
self.ndim = ndim
171+
self.radius = radius
172+
self.ball = L01BallProj(self.radius if not callable(radius) else radius(0))
173+
self.count = 0
174+
175+
def __call__(self, x, tol=1e-4):
176+
x = x.reshape(self.ndim, len(x) // self.ndim)
177+
radius = _current_sigma(self.radius, self.count)
178+
return np.linalg.norm(np.linalg.norm(x, ord=1, axis=0), ord=0) <= radius
179+
180+
def _increment_count(func):
181+
"""Increment counter
182+
"""
183+
def wrapped(self, *args, **kwargs):
184+
self.count += 1
185+
return func(self, *args, **kwargs)
186+
return wrapped
187+
188+
@_increment_count
189+
@_check_tau
190+
def prox(self, x, tau):
191+
x = x.reshape(self.ndim, len(x) // self.ndim)
192+
radius = _current_sigma(self.radius, self.count)
193+
self.ball.radius = radius
194+
y = self.ball(x)
195+
return y.ravel()

pyproximal/proximal/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Nonlinear Nonlinear function
1313
L0 L0 Norm
1414
L0Ball L0 Ball
15+
L01pBall L0,1 Ball
1516
L1 L1 Norm
1617
L1Ball L1 Ball
1718
Euclidean Euclidean Norm
@@ -67,7 +68,7 @@
6768
from .Hankel import *
6869

6970
__all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic',
70-
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2',
71+
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L01Ball', 'L1', 'L1Ball', 'L2',
7172
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'RelaxedMumfordShah',
7273
'Nuclear', 'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
7374
'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty',

pytests/test_projection.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from numpy.testing import assert_array_almost_equal
55
from pylops.basicoperators import Identity
66
from pyproximal.utils import moreau
7-
from pyproximal.proximal import Box, EuclideanBall, L0Ball, L1Ball, \
7+
from pyproximal.proximal import Box, EuclideanBall, L0Ball, L01Ball, L1Ball, \
88
NuclearBall, Simplex, AffineSet, Hankel
99

1010
par1 = {'nx': 10, 'ny': 8, 'axis': 0, 'dtype': 'float32'} # even float32 dir0
@@ -65,6 +65,25 @@ def test_L0Ball(par):
6565
assert moreau(l0, x, tau)
6666

6767

68+
@pytest.mark.parametrize("par", [(par1), (par2)])
69+
def test_L01Ball(par):
70+
"""L01 Ball projection and proximal/dual proximal of related indicator
71+
"""
72+
np.random.seed(10)
73+
74+
l0 = L01Ball(3, 1)
75+
x = np.random.normal(0., 1., (3, par['nx'])).astype(par['dtype']).ravel() + 1.
76+
77+
# evaluation
78+
assert l0(x) == False
79+
xp = l0.prox(x, 1.)
80+
assert l0(xp) == True
81+
82+
# prox / dualprox
83+
tau = 2.
84+
assert moreau(l0, x, tau)
85+
86+
6887
@pytest.mark.parametrize("par", [(par1), (par2)])
6988
def test_L1Ball(par):
7089
"""L1 Ball projection and proximal/dual proximal of related indicator

0 commit comments

Comments
 (0)