11import numpy as np
22
33from pyproximal .ProxOperator import _check_tau
4- from pyproximal .projection import L0BallProj
4+ from pyproximal .projection import L0BallProj , L01BallProj
55from pyproximal import ProxOperator
66from pyproximal .proximal .L1 import _current_sigma
77
@@ -35,7 +35,7 @@ def _hardthreshold(x, thresh):
3535
3636
3737class L0 (ProxOperator ):
38- r"""L0 norm proximal operator.
38+ r""":math:`L_0` norm proximal operator.
3939
4040 Proximal operator of the :math:`\ell_0` norm:
4141 :math:`\sigma\|\mathbf{x}\|_0 = \text{count}(x_i \ne 0)`.
@@ -92,7 +92,7 @@ def prox(self, x, tau):
9292
9393
9494class L0Ball (ProxOperator ):
95- r"""L0 ball proximal operator.
95+ r""":math:`L_0` ball proximal operator.
9696
9797 Proximal operator of the L0 ball: :math:`L0_{r} =
9898 \{ \mathbf{x}: ||\mathbf{x}||_0 \leq r \}`.
@@ -103,7 +103,6 @@ class L0Ball(ProxOperator):
103103 Radius. This can be a constant number or a function that is called passing a
104104 counter which keeps track of how many times the ``prox`` method has been
105105 invoked before and returns a scalar ``radius`` to be used.
106- Radius
107106
108107 Notes
109108 -----
@@ -136,4 +135,61 @@ def prox(self, x, tau):
136135 radius = _current_sigma (self .radius , self .count )
137136 self .ball .radius = radius
138137 y = self .ball (x )
139- return y
138+ return y
139+
140+
141+ class L01Ball (ProxOperator ):
142+ r""":math:`L_{0,1}` ball proximal operator.
143+
144+ Proximal operator of the :math:`L_{0,1}` ball: :math:`L_{0,1}^{r} =
145+ \{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1, ||\mathbf{x}_2||_1, ...,
146+ ||\mathbf{x}_1||_1] \ne 0) \leq r \}`
147+
148+ Parameters
149+ ----------
150+ ndim : :obj:`int`
151+ Number of dimensions :math:`N_{dim}`. Used to reshape the input array
152+ in a matrix of size :math:`N_{dim} \times N'_{x}` where
153+ :math:`N'_x = \frac{N_x}{N_{dim}}`. Note that the input
154+ vector ``x`` should be created by stacking vectors from different
155+ dimensions.
156+ radius : :obj:`int` or :obj:`func`, optional
157+ Radius. This can be a constant number or a function that is called passing a
158+ counter which keeps track of how many times the ``prox`` method has been
159+ invoked before and returns a scalar ``radius`` to be used.
160+
161+ Notes
162+ -----
163+ As the L0 ball is an indicator function, the proximal operator
164+ corresponds to its orthogonal projection
165+ (see :class:`pyproximal.projection.L01BallProj` for details.
166+
167+ """
168+ def __init__ (self , ndim , radius ):
169+ super ().__init__ (None , False )
170+ self .ndim = ndim
171+ self .radius = radius
172+ self .ball = L01BallProj (self .radius if not callable (radius ) else radius (0 ))
173+ self .count = 0
174+
175+ def __call__ (self , x , tol = 1e-4 ):
176+ x = x .reshape (self .ndim , len (x ) // self .ndim )
177+ radius = _current_sigma (self .radius , self .count )
178+ return np .linalg .norm (np .linalg .norm (x , ord = 1 , axis = 0 ), ord = 0 ) <= radius
179+
180+ def _increment_count (func ):
181+ """Increment counter
182+ """
183+ def wrapped (self , * args , ** kwargs ):
184+ self .count += 1
185+ return func (self , * args , ** kwargs )
186+ return wrapped
187+
188+ @_increment_count
189+ @_check_tau
190+ def prox (self , x , tau ):
191+ x = x .reshape (self .ndim , len (x ) // self .ndim )
192+ radius = _current_sigma (self .radius , self .count )
193+ self .ball .radius = radius
194+ y = self .ball (x )
195+ return y .ravel ()
0 commit comments