| 
 | 1 | +import numpy as np  | 
 | 2 | + | 
 | 3 | +from pyproximal.ProxOperator import _check_tau  | 
 | 4 | +from pyproximal import ProxOperator  | 
 | 5 | +from pyproximal.proximal.L1 import _current_sigma  | 
 | 6 | + | 
 | 7 | + | 
 | 8 | +def _l2(x, alpha):  | 
 | 9 | +    r"""Scaling operation.  | 
 | 10 | +
  | 
 | 11 | +    Applies the proximal of ``alpha||y - x||_2^2`` which is essentially a scaling operation.  | 
 | 12 | +
  | 
 | 13 | +    Parameters  | 
 | 14 | +    ----------  | 
 | 15 | +    x : :obj:`numpy.ndarray`  | 
 | 16 | +        Vector  | 
 | 17 | +    alpha : :obj:`float`  | 
 | 18 | +        Scaling parameter  | 
 | 19 | +
  | 
 | 20 | +    Returns  | 
 | 21 | +    -------  | 
 | 22 | +    y : :obj:`numpy.ndarray`  | 
 | 23 | +        Scaled vector  | 
 | 24 | +
  | 
 | 25 | +    """  | 
 | 26 | +    y = 1 / (1 + 2 * alpha) * x  | 
 | 27 | +    return y  | 
 | 28 | + | 
 | 29 | + | 
 | 30 | +def _current_kappa(kappa, count):  | 
 | 31 | +    if not callable(kappa):  | 
 | 32 | +        return kappa  | 
 | 33 | +    else:  | 
 | 34 | +        return kappa(count)  | 
 | 35 | + | 
 | 36 | + | 
 | 37 | +class RelaxedMumfordShah(ProxOperator):  | 
 | 38 | +    r"""Relaxed Mumford-Shah norm proximal operator.  | 
 | 39 | +
  | 
 | 40 | +    Proximal operator of the relaxed Mumford-Shah norm:  | 
 | 41 | +    :math:`\text{rMS}(x) = \min (\alpha\Vert x\Vert_2^2, \kappa)`.  | 
 | 42 | +
  | 
 | 43 | +    Parameters  | 
 | 44 | +    ----------  | 
 | 45 | +    sigma : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional  | 
 | 46 | +        Multiplicative coefficient of L2 norm that controls the smoothness of the solutuon.  | 
 | 47 | +        This can be a constant number, a list of values (for multidimensional inputs, acting  | 
 | 48 | +        on the second dimension) or a function that is called passing a counter which keeps  | 
 | 49 | +        track of how many times the ``prox`` method has been invoked before and returns a  | 
 | 50 | +        scalar (or a list of) ``sigma`` to be used.  | 
 | 51 | +    kappa : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional  | 
 | 52 | +        Constant value in the rMS norm which essentially controls when the norm allows a jump. This can be a  | 
 | 53 | +        constant number, a list of values (for multidimensional inputs, acting on the second dimension) or  | 
 | 54 | +        a function that is called passing a counter which keeps track of how many  | 
 | 55 | +        times the ``prox`` method has been invoked before and returns a scalar (or a list of)  | 
 | 56 | +        ``kappa`` to be used.  | 
 | 57 | +
  | 
 | 58 | +    Notes  | 
 | 59 | +    -----  | 
 | 60 | +    The :math:`rMS` proximal operator is defined as [1]_:  | 
 | 61 | +
  | 
 | 62 | +    .. math::  | 
 | 63 | +        \text{prox}_{\tau \text{rMS}}(x) =  | 
 | 64 | +        \begin{cases}  | 
 | 65 | +        \frac{1}{1+2\tau\alpha}x & \text{ if } & \vert x\vert \leq \sqrt{\frac{\kappa}{\alpha}(1 + 2\tau\alpha)} \\  | 
 | 66 | +        \kappa & \text{ else }  | 
 | 67 | +        \end{cases}.  | 
 | 68 | +
  | 
 | 69 | +    .. [1] Strekalovskiy, E., and D. Cremers, 2014, Real-time minimization of the piecewise smooth  | 
 | 70 | +            Mumford-Shah functional: European Conference on Computer Vision, 127–141.  | 
 | 71 | +
  | 
 | 72 | +    """  | 
 | 73 | +    def __init__(self, sigma=1., kappa=1.):  | 
 | 74 | +        super().__init__(None, False)  | 
 | 75 | +        self.sigma = sigma  | 
 | 76 | +        self.kappa = kappa  | 
 | 77 | +        self.count = 0  | 
 | 78 | + | 
 | 79 | +    def __call__(self, x):  | 
 | 80 | +        sigma = _current_sigma(self.sigma, self.count)  | 
 | 81 | +        kappa = _current_sigma(self.kappa, self.count)  | 
 | 82 | +        return np.minimum(sigma * np.linalg.norm(x) ** 2, kappa)  | 
 | 83 | + | 
 | 84 | +    def _increment_count(func):  | 
 | 85 | +        """Increment counter  | 
 | 86 | +        """  | 
 | 87 | +        def wrapped(self, *args, **kwargs):  | 
 | 88 | +            self.count += 1  | 
 | 89 | +            return func(self, *args, **kwargs)  | 
 | 90 | +        return wrapped  | 
 | 91 | + | 
 | 92 | +    @_increment_count  | 
 | 93 | +    @_check_tau  | 
 | 94 | +    def prox(self, x, tau):  | 
 | 95 | +        sigma = _current_sigma(self.sigma, self.count)  | 
 | 96 | +        kappa = _current_sigma(self.kappa, self.count)  | 
 | 97 | + | 
 | 98 | +        x = np.where(np.abs(x) <= np.sqrt(kappa / sigma * (1 + 2 * tau * sigma)), _l2(x, tau * sigma), x)  | 
 | 99 | +        return x  | 
0 commit comments