Skip to content

Commit 2feec64

Browse files
committed
Added L0 proximal operator
1 parent b02a92a commit 2feec64

File tree

4 files changed

+94
-3
lines changed

4 files changed

+94
-3
lines changed

pyproximal/projection/L1.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,8 @@ def __init__(self, n, radius, maxiter=100, xtol=1e-5):
4040
self.simplex = SimplexProj(n, radius, maxiter, xtol)
4141

4242
def __call__(self, x):
43-
return np.sign(x) * self.simplex(np.abs(x))
43+
if np.iscomplexobj(x):
44+
return np.exp(1j * np.angle(x)) * self.simplex(np.abs(x))
45+
else:
46+
return np.sign(x) * self.simplex(np.abs(x))
47+

pyproximal/proximal/L0.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,92 @@
33
from pyproximal.ProxOperator import _check_tau
44
from pyproximal.projection import L0BallProj
55
from pyproximal import ProxOperator
6+
from pyproximal.proximal.L1 import _current_sigma
7+
8+
9+
def _hardthreshold(x, thresh):
10+
r"""Hard thresholding.
11+
12+
Applies hard thresholding to vector ``x`` (equal to the proximity
13+
operator for :math:`\|\mathbf{x}\|_0`) as shown in [1]_.
14+
15+
.. [1] Chen, F., Shen, L., Suter, B.W., "Computing the proximity
16+
operator of the Lp norm with 0 < p < 1",
17+
IET Signal Processing, 10, 2016.
18+
19+
Parameters
20+
----------
21+
x : :obj:`numpy.ndarray`
22+
Vector
23+
thresh : :obj:`float`
24+
Threshold
25+
26+
Returns
27+
-------
28+
x1 : :obj:`numpy.ndarray`
29+
Tresholded vector
30+
31+
"""
32+
x1 = x.copy()
33+
x1[np.abs(x) <= thresh] = 0
34+
return x1
35+
36+
37+
class L0(ProxOperator):
38+
r"""L0 norm proximal operator.
39+
40+
Proximal operator of the :math:`\ell_0` norm:
41+
:math:`\sigma\|\mathbf{x}\|_0 = \text{count}(x_i \ne 0)`.
42+
43+
Parameters
44+
----------
45+
sigma : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
46+
Multiplicative coefficient of L1 norm. This can be a constant number, a list
47+
of values (for multidimensional inputs, acting on the second dimension) or
48+
a function that is called passing a counter which keeps track of how many
49+
times the ``prox`` method has been invoked before and returns a scalar (or a list of)
50+
``sigma`` to be used.
51+
52+
Notes
53+
-----
54+
The :math:`\ell_0` proximal operator is defined as:
55+
56+
.. math::
57+
58+
\prox_{\tau \sigma \|\cdot\|_0}(\mathbf{x}) =
59+
\operatorname{hard}(\mathbf{x}, \tau \sigma) =
60+
\begin{cases}
61+
x_i, & x_i < -\tau \sigma \\
62+
0, & -\tau\sigma \leq x_i \leq \tau\sigma \\
63+
x_i, & x_i > \tau\sigma\\
64+
\end{cases}
65+
66+
where :math:`\operatorname{hard}` is the so-called called *hard thresholding*.
67+
68+
"""
69+
def __init__(self, sigma=1.):
70+
super().__init__(None, False)
71+
self.sigma = sigma
72+
self.count = 0
73+
74+
def __call__(self, x):
75+
sigma = _current_sigma(self.sigma, self.count)
76+
return np.sum(np.abs(x) > sigma)
77+
78+
def _increment_count(func):
79+
"""Increment counter
80+
"""
81+
def wrapped(self, *args, **kwargs):
82+
self.count += 1
83+
return func(self, *args, **kwargs)
84+
return wrapped
85+
86+
@_increment_count
87+
@_check_tau
88+
def prox(self, x, tau):
89+
sigma = _current_sigma(self.sigma, self.count)
90+
x = _hardthreshold(x, tau * sigma)
91+
return x
692

793

894
class L0Ball(ProxOperator):

pyproximal/proximal/L1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class L1(ProxOperator):
6767
\operatorname{soft}(\mathbf{x}, \tau \sigma) =
6868
\begin{cases}
6969
x_i + \tau \sigma, & x_i - g_i < -\tau \sigma \\
70-
g_i, & -\sigma \leq x_i - g_i \leq \tau\sigma \\
70+
g_i, & -\tau\sigma \leq x_i - g_i \leq \tau\sigma \\
7171
x_i - \tau\sigma, & x_i - g_i > \tau\sigma\\
7272
\end{cases}
7373

pyproximal/proximal/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
AffineSet Affines set indicator
1111
Quadratic Quadratic function
1212
Nonlinear Nonlinear function
13+
L0 L0 Norm
1314
L0Ball L0 Ball
1415
L1 L1 Norm
1516
L1Ball L1 Ball
@@ -57,7 +58,7 @@
5758
from .SingularValuePenalty import *
5859

5960
__all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic',
60-
'Euclidean', 'EuclideanBall', 'L0Ball', 'L1', 'L1Ball', 'L2',
61+
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2',
6162
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'Nuclear',
6263
'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
6364
'Log', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty']

0 commit comments

Comments
 (0)