Skip to content

Commit a9cca82

Browse files
authored
Merge pull request #142 from NickLuiken/dev
rMS proximal operator
2 parents a9f3249 + fd56deb commit a9cca82

File tree

6 files changed

+348
-9
lines changed

6 files changed

+348
-9
lines changed

docs/source/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Non-Convex
9292
Log1
9393
QuadraticEnvelopeCard
9494
QuadraticEnvelopeCardIndicator
95+
RelaxedMumfordShah
9596
SCAD
9697

9798

pyproximal/proximal/RelaxedMS.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import numpy as np
2+
3+
from pyproximal.ProxOperator import _check_tau
4+
from pyproximal import ProxOperator
5+
from pyproximal.proximal.L1 import _current_sigma
6+
7+
8+
def _l2(x, alpha):
9+
r"""Scaling operation.
10+
11+
Applies the proximal of ``alpha||y - x||_2^2`` which is essentially a scaling operation.
12+
13+
Parameters
14+
----------
15+
x : :obj:`numpy.ndarray`
16+
Vector
17+
alpha : :obj:`float`
18+
Scaling parameter
19+
20+
Returns
21+
-------
22+
y : :obj:`numpy.ndarray`
23+
Scaled vector
24+
25+
"""
26+
y = 1 / (1 + 2 * alpha) * x
27+
return y
28+
29+
30+
def _current_kappa(kappa, count):
31+
if not callable(kappa):
32+
return kappa
33+
else:
34+
return kappa(count)
35+
36+
37+
class RelaxedMumfordShah(ProxOperator):
38+
r"""Relaxed Mumford-Shah norm proximal operator.
39+
40+
Proximal operator of the relaxed Mumford-Shah norm:
41+
:math:`\text{rMS}(x) = \min (\alpha\Vert x\Vert_2^2, \kappa)`.
42+
43+
Parameters
44+
----------
45+
sigma : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
46+
Multiplicative coefficient of L2 norm that controls the smoothness of the solutuon.
47+
This can be a constant number, a list of values (for multidimensional inputs, acting
48+
on the second dimension) or a function that is called passing a counter which keeps
49+
track of how many times the ``prox`` method has been invoked before and returns a
50+
scalar (or a list of) ``sigma`` to be used.
51+
kappa : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
52+
Constant value in the rMS norm which essentially controls when the norm allows a jump. This can be a
53+
constant number, a list of values (for multidimensional inputs, acting on the second dimension) or
54+
a function that is called passing a counter which keeps track of how many
55+
times the ``prox`` method has been invoked before and returns a scalar (or a list of)
56+
``kappa`` to be used.
57+
58+
Notes
59+
-----
60+
The :math:`rMS` proximal operator is defined as [1]_:
61+
62+
.. math::
63+
\text{prox}_{\tau \text{rMS}}(x) =
64+
\begin{cases}
65+
\frac{1}{1+2\tau\alpha}x & \text{ if } & \vert x\vert \leq \sqrt{\frac{\kappa}{\alpha}(1 + 2\tau\alpha)} \\
66+
\kappa & \text{ else }
67+
\end{cases}.
68+
69+
.. [1] Strekalovskiy, E., and D. Cremers, 2014, Real-time minimization of the piecewise smooth
70+
Mumford-Shah functional: European Conference on Computer Vision, 127–141.
71+
72+
"""
73+
def __init__(self, sigma=1., kappa=1.):
74+
super().__init__(None, False)
75+
self.sigma = sigma
76+
self.kappa = kappa
77+
self.count = 0
78+
79+
def __call__(self, x):
80+
sigma = _current_sigma(self.sigma, self.count)
81+
kappa = _current_sigma(self.kappa, self.count)
82+
return np.minimum(sigma * np.linalg.norm(x) ** 2, kappa)
83+
84+
def _increment_count(func):
85+
"""Increment counter
86+
"""
87+
def wrapped(self, *args, **kwargs):
88+
self.count += 1
89+
return func(self, *args, **kwargs)
90+
return wrapped
91+
92+
@_increment_count
93+
@_check_tau
94+
def prox(self, x, tau):
95+
sigma = _current_sigma(self.sigma, self.count)
96+
kappa = _current_sigma(self.kappa, self.count)
97+
98+
x = np.where(np.abs(x) <= np.sqrt(kappa / sigma * (1 + 2 * tau * sigma)), _l2(x, tau * sigma), x)
99+
return x

pyproximal/proximal/__init__.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,22 @@
77
Box Box indicator
88
Simplex Simplex indicator
99
Intersection Intersection indicator
10-
AffineSet Affines set indicator
10+
AffineSet Affines set indicator
1111
Quadratic Quadratic function
12-
Nonlinear Nonlinear function
12+
Nonlinear Nonlinear function
1313
L0 L0 Norm
1414
L0Ball L0 Ball
1515
L1 L1 Norm
1616
L1Ball L1 Ball
17-
Euclidean Euclidean Norm
18-
EuclideanBall Euclidean Ball
17+
Euclidean Euclidean Norm
18+
EuclideanBall Euclidean Ball
1919
L2 L2 Norm
2020
L2Convolve L2 Norm of convolution operator
2121
L21 L2,1 Norm
2222
L21_plus_L1 L2,1 + L1 mixed-norm
23-
Huber Huber Norm
24-
TV Total Variation Norm
23+
Huber Huber Norm
24+
TV Total Variation Norm
25+
RelaxedMumfordShah Relaxed Mumford Shah Norm
2526
Nuclear Nuclear Norm
2627
NuclearBall Nuclear Ball
2728
Orthogonal Product between orthogonal operator and vector
@@ -53,6 +54,7 @@
5354
from .L21_plus_L1 import *
5455
from .Huber import *
5556
from .TV import *
57+
from .RelaxedMS import *
5658
from .Nuclear import *
5759
from .Orthogonal import *
5860
from .VStack import *
@@ -66,8 +68,8 @@
6668

6769
__all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic',
6870
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2',
69-
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'Nuclear',
70-
'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
71+
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'RelaxedMumfordShah',
72+
'Nuclear', 'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
7173
'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty',
7274
'QuadraticEnvelopeCardIndicator', 'QuadraticEnvelopeRankL2',
7375
'Hankel']

pytests/test_norms.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from pylops.basicoperators import Identity, Diagonal, MatrixMult, FirstDerivative
77
from pyproximal.utils import moreau
8-
from pyproximal.proximal import Box, Euclidean, L2, L1, L21, L21_plus_L1, Huber, Nuclear, TV
8+
from pyproximal.proximal import Box, Euclidean, L2, L1, L21, L21_plus_L1, \
9+
Huber, Nuclear, RelaxedMumfordShah, TV
910

1011
par1 = {'nx': 10, 'sigma': 1., 'dtype': 'float32'} # even float32
1112
par2 = {'nx': 11, 'sigma': 2., 'dtype': 'float64'} # odd float64
@@ -202,6 +203,22 @@ def test_TV(par):
202203
assert_array_almost_equal(tv(x), par['sigma'] * np.sum(np.abs(dx), axis=0))
203204

204205

206+
@pytest.mark.parametrize("par", [(par1), (par2)])
207+
def test_rMS(par):
208+
"""rMS norm and proximal/dual proximal
209+
"""
210+
kappa = 1.
211+
rMS = RelaxedMumfordShah(sigma=par['sigma'], kappa=kappa)
212+
213+
# norm
214+
x = np.random.normal(0., 1., par['nx']).astype(par['dtype'])
215+
assert rMS(x) == np.minimum(par['sigma'] * np.linalg.norm(x) ** 2, kappa)
216+
217+
# prox / dualprox
218+
tau = 2.
219+
assert moreau(rMS, x, tau)
220+
221+
205222
def test_Nuclear_FOM():
206223
"""Nuclear norm benchmark with FOM solver
207224
"""

testdata/marmousi_trace.npy

1012 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)