Skip to content

Commit f89d04e

Browse files
committed
minor: small changes to rMS and associated tutorial
1 parent 0f169af commit f89d04e

File tree

5 files changed

+126
-137
lines changed

5 files changed

+126
-137
lines changed

docs/source/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Non-Convex
9292
Log1
9393
QuadraticEnvelopeCard
9494
QuadraticEnvelopeCardIndicator
95+
RelaxedMumfordShah
9596
SCAD
9697

9798

pyproximal/proximal/rMS.py renamed to pyproximal/proximal/RelaxedMS.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pyproximal.ProxOperator import _check_tau
44
from pyproximal import ProxOperator
5+
from pyproximal.proximal.L1 import _current_sigma
56

67

78
def _l2(x, thresh):
@@ -26,34 +27,27 @@ def _l2(x, thresh):
2627
return y
2728

2829

29-
def _current_sigma(sigma, count):
30-
if not callable(sigma):
31-
return sigma
32-
else:
33-
return sigma(count)
34-
35-
3630
def _current_kappa(kappa, count):
3731
if not callable(kappa):
3832
return kappa
3933
else:
4034
return kappa(count)
4135

4236

43-
class rMS(ProxOperator):
44-
r"""relaxed Mumford-Shoh norm proximal operator.
37+
class RelaxedMumfordShah(ProxOperator):
38+
r"""Relaxed Mumford-Shah norm proximal operator.
4539
4640
Proximal operator of the relaxed Mumford-Shah norm:
47-
:math:`\text{rMS}(x) = \min (\alpha\Vert x\Vert_2^2, \kappa).`.
41+
:math:`\text{rMS}(x) = \min (\alpha\Vert x\Vert_2^2, \kappa)`.
4842
4943
Parameters
5044
----------
5145
sigma : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
52-
Multiplicative coefficient of L2 norm that controls the smoothness. This can be a constant number, a list
53-
of values (for multidimensional inputs, acting on the second dimension) or
54-
a function that is called passing a counter which keeps track of how many
55-
times the ``prox`` method has been invoked before and returns a scalar (or a list of)
56-
``sigma`` to be used.
46+
Multiplicative coefficient of L2 norm that controls the smoothness of the solutuon.
47+
This can be a constant number, a list of values (for multidimensional inputs, acting
48+
on the second dimension) or a function that is called passing a counter which keeps
49+
track of how many times the ``prox`` method has been invoked before and returns a
50+
scalar (or a list of) ``sigma`` to be used.
5751
kappa : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
5852
Constant value in the rMS norm which essentially controls when the norm allows a jump. This can be a
5953
constant number, a list of values (for multidimensional inputs, acting on the second dimension) or
@@ -65,12 +59,12 @@ class rMS(ProxOperator):
6559
6660
Notes
6761
-----
68-
The :math:`\ell_1` proximal operator is defined as [1]_:
62+
The :math:`rMS` proximal operator is defined as [1]_:
6963
7064
.. math::
71-
\text{prox}_{\text{rMS}}(x) =
65+
\text{prox}_{\tau \text{rMS}}(x) =
7266
\begin{cases}
73-
\frac{1}{1+2\alpha}x & \text{ if } & \vert x\vert \leq \sqrt{\frac{\kappa}{\alpha}(1 + 2\alpha)} \\
67+
\frac{1}{1+2\tau\alpha}x & \text{ if } & \vert x\vert \leq \sqrt{\frac{\kappa}{\alpha}(1 + 2\tau\alpha)} \\
7468
\kappa & \text{ else }
7569
\end{cases}.
7670
@@ -89,7 +83,7 @@ def __init__(self, sigma=1., kappa=1., g=None):
8983
def __call__(self, x):
9084
sigma = _current_sigma(self.sigma, self.count)
9185
kappa = _current_sigma(self.kappa, self.count)
92-
return np.minimum(sigma * np.linalg.norm(x)**2, kappa)
86+
return np.minimum(sigma * np.linalg.norm(x) ** 2, kappa)
9387

9488
def _increment_count(func):
9589
"""Increment counter
@@ -107,10 +101,3 @@ def prox(self, x, tau):
107101

108102
x = np.where(np.abs(x) <= np.sqrt(kappa / sigma * (1 + 2 * tau * sigma)), _l2(x, tau * sigma), x)
109103
return x
110-
111-
@_check_tau
112-
def proxdual(self, x, tau):
113-
# x - tau * self.prox(x / tau, 1. / tau)
114-
x = self._proxdual_moreau(x, tau)
115-
116-
return x

pyproximal/proximal/__init__.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,22 @@
77
Box Box indicator
88
Simplex Simplex indicator
99
Intersection Intersection indicator
10-
AffineSet Affines set indicator
10+
AffineSet Affines set indicator
1111
Quadratic Quadratic function
12-
Nonlinear Nonlinear function
12+
Nonlinear Nonlinear function
1313
L0 L0 Norm
1414
L0Ball L0 Ball
1515
L1 L1 Norm
1616
L1Ball L1 Ball
17-
Euclidean Euclidean Norm
18-
EuclideanBall Euclidean Ball
17+
Euclidean Euclidean Norm
18+
EuclideanBall Euclidean Ball
1919
L2 L2 Norm
2020
L2Convolve L2 Norm of convolution operator
2121
L21 L2,1 Norm
2222
L21_plus_L1 L2,1 + L1 mixed-norm
23-
Huber Huber Norm
24-
TV Total Variation Norm
23+
Huber Huber Norm
24+
TV Total Variation Norm
25+
RelaxedMumfordShah Relaxed Mumford Shah Norm
2526
Nuclear Nuclear Norm
2627
NuclearBall Nuclear Ball
2728
Orthogonal Product between orthogonal operator and vector
@@ -53,6 +54,7 @@
5354
from .L21_plus_L1 import *
5455
from .Huber import *
5556
from .TV import *
57+
from .RelaxedMS import *
5658
from .Nuclear import *
5759
from .Orthogonal import *
5860
from .VStack import *
@@ -66,8 +68,8 @@
6668

6769
__all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic',
6870
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2',
69-
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'Nuclear',
70-
'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
71-
'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'rMS', 'SingularValuePenalty',
71+
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'RelaxedMumfordShah',
72+
'Nuclear', 'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
73+
'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty',
7274
'QuadraticEnvelopeCardIndicator', 'QuadraticEnvelopeRankL2',
7375
'Hankel']

pytests/test_norms.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from pylops.basicoperators import Identity, Diagonal, MatrixMult, FirstDerivative
77
from pyproximal.utils import moreau
8-
from pyproximal.proximal import Box, Euclidean, L2, L1, L21, L21_plus_L1, Huber, Nuclear, TV
8+
from pyproximal.proximal import Box, Euclidean, L2, L1, L21, L21_plus_L1, \
9+
Huber, Nuclear, RelaxedMumfordShah, TV
910

1011
par1 = {'nx': 10, 'sigma': 1., 'dtype': 'float32'} # even float32
1112
par2 = {'nx': 11, 'sigma': 2., 'dtype': 'float64'} # odd float64
@@ -202,6 +203,22 @@ def test_TV(par):
202203
assert_array_almost_equal(tv(x), par['sigma'] * np.sum(np.abs(dx), axis=0))
203204

204205

206+
@pytest.mark.parametrize("par", [(par1), (par2)])
207+
def test_rMS(par):
208+
"""rMS norm and proximal/dual proximal
209+
"""
210+
kappa = 1.
211+
rMS = RelaxedMumfordShah(sigma=par['sigma'], kappa=kappa)
212+
213+
# norm
214+
x = np.random.normal(0., 1., par['nx']).astype(par['dtype'])
215+
assert rMS(x) == np.minimum(par['sigma'] * np.linalg.norm(x) ** 2, kappa)
216+
217+
# prox / dualprox
218+
tau = 2.
219+
assert moreau(rMS, x, tau)
220+
221+
205222
def test_Nuclear_FOM():
206223
"""Nuclear norm benchmark with FOM solver
207224
"""

0 commit comments

Comments
 (0)