|
3 | 3 | from pyproximal.ProxOperator import _check_tau |
4 | 4 | from pyproximal.projection import L0BallProj |
5 | 5 | from pyproximal import ProxOperator |
| 6 | +from pyproximal.proximal.L1 import _current_sigma |
| 7 | + |
| 8 | + |
| 9 | +def _hardthreshold(x, thresh): |
| 10 | + r"""Hard thresholding. |
| 11 | +
|
| 12 | + Applies hard thresholding to vector ``x`` (equal to the proximity |
| 13 | + operator for :math:`\|\mathbf{x}\|_0`) as shown in [1]_. |
| 14 | +
|
| 15 | + .. [1] Chen, F., Shen, L., Suter, B.W., "Computing the proximity |
| 16 | + operator of the Lp norm with 0 < p < 1", |
| 17 | + IET Signal Processing, 10, 2016. |
| 18 | +
|
| 19 | + Parameters |
| 20 | + ---------- |
| 21 | + x : :obj:`numpy.ndarray` |
| 22 | + Vector |
| 23 | + thresh : :obj:`float` |
| 24 | + Threshold |
| 25 | +
|
| 26 | + Returns |
| 27 | + ------- |
| 28 | + x1 : :obj:`numpy.ndarray` |
| 29 | + Tresholded vector |
| 30 | +
|
| 31 | + """ |
| 32 | + x1 = x.copy() |
| 33 | + x1[np.abs(x) <= thresh] = 0 |
| 34 | + return x1 |
| 35 | + |
| 36 | + |
| 37 | +class L0(ProxOperator): |
| 38 | + r"""L0 norm proximal operator. |
| 39 | +
|
| 40 | + Proximal operator of the :math:`\ell_0` norm: |
| 41 | + :math:`\sigma\|\mathbf{x}\|_0 = \text{count}(x_i \ne 0)`. |
| 42 | +
|
| 43 | + Parameters |
| 44 | + ---------- |
| 45 | + sigma : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional |
| 46 | + Multiplicative coefficient of L1 norm. This can be a constant number, a list |
| 47 | + of values (for multidimensional inputs, acting on the second dimension) or |
| 48 | + a function that is called passing a counter which keeps track of how many |
| 49 | + times the ``prox`` method has been invoked before and returns a scalar (or a list of) |
| 50 | + ``sigma`` to be used. |
| 51 | +
|
| 52 | + Notes |
| 53 | + ----- |
| 54 | + The :math:`\ell_0` proximal operator is defined as: |
| 55 | +
|
| 56 | + .. math:: |
| 57 | +
|
| 58 | + \prox_{\tau \sigma \|\cdot\|_0}(\mathbf{x}) = |
| 59 | + \operatorname{hard}(\mathbf{x}, \tau \sigma) = |
| 60 | + \begin{cases} |
| 61 | + x_i, & x_i < -\tau \sigma \\ |
| 62 | + 0, & -\tau\sigma \leq x_i \leq \tau\sigma \\ |
| 63 | + x_i, & x_i > \tau\sigma\\ |
| 64 | + \end{cases} |
| 65 | +
|
| 66 | + where :math:`\operatorname{hard}` is the so-called called *hard thresholding*. |
| 67 | +
|
| 68 | + """ |
| 69 | + def __init__(self, sigma=1.): |
| 70 | + super().__init__(None, False) |
| 71 | + self.sigma = sigma |
| 72 | + self.count = 0 |
| 73 | + |
| 74 | + def __call__(self, x): |
| 75 | + sigma = _current_sigma(self.sigma, self.count) |
| 76 | + return np.sum(np.abs(x) > sigma) |
| 77 | + |
| 78 | + def _increment_count(func): |
| 79 | + """Increment counter |
| 80 | + """ |
| 81 | + def wrapped(self, *args, **kwargs): |
| 82 | + self.count += 1 |
| 83 | + return func(self, *args, **kwargs) |
| 84 | + return wrapped |
| 85 | + |
| 86 | + @_increment_count |
| 87 | + @_check_tau |
| 88 | + def prox(self, x, tau): |
| 89 | + sigma = _current_sigma(self.sigma, self.count) |
| 90 | + x = _hardthreshold(x, tau * sigma) |
| 91 | + return x |
6 | 92 |
|
7 | 93 |
|
8 | 94 | class L0Ball(ProxOperator): |
|
0 commit comments