
Commit 6cf849e

Added rMS proximal operator and rMS tutorial
1 parent 54bc6ba commit 6cf849e

File tree

3 files changed: +368 −1 lines changed


pyproximal/proximal/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -68,6 +68,6 @@
 'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2',
 'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'Nuclear',
 'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
-'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty',
+'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'rMS', 'SingularValuePenalty',
 'QuadraticEnvelopeCardIndicator', 'QuadraticEnvelopeRankL2',
 'Hankel']

pyproximal/proximal/rMS.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
import numpy as np

from pyproximal.ProxOperator import _check_tau
from pyproximal import ProxOperator


def _l2(x, thresh):
    r"""Quadratic shrinkage.

    Applies the proximal operator of the squared :math:`\ell_2` norm,
    i.e., scales ``x`` by a factor ``1 / (1 + 2 * thresh)``.

    Parameters
    ----------
    x : :obj:`numpy.ndarray`
        Vector
    thresh : :obj:`float`
        Threshold

    Returns
    -------
    y : :obj:`numpy.ndarray`
        Shrunk vector

    """
    y = 1 / (1 + 2 * thresh) * x
    return y


def _current_sigma(sigma, count):
    if not callable(sigma):
        return sigma
    else:
        return sigma(count)


def _current_kappa(kappa, count):
    if not callable(kappa):
        return kappa
    else:
        return kappa(count)


class rMS(ProxOperator):
    r"""Relaxed Mumford-Shah penalty proximal operator.

    Proximal operator of the relaxed Mumford-Shah (rMS) penalty:
    :math:`\text{rMS}(\mathbf{x}) = \min(\sigma\|\mathbf{x}\|_2^2, \kappa)`.

    Parameters
    ----------
    sigma : :obj:`float` or :obj:`func`, optional
        Multiplicative coefficient of the quadratic term. This can be a
        constant number or a function that is called passing a counter
        which keeps track of how many times the ``prox`` method has been
        invoked before and returns a scalar ``sigma`` to be used.
    kappa : :obj:`float` or :obj:`func`, optional
        Constant penalty on jumps. Like ``sigma``, this can be a constant
        number or a function of the invocation counter.
    g : :obj:`np.ndarray`, optional
        Vector to be subtracted (stored but not currently used by ``prox``)

    Notes
    -----
    The rMS proximal operator is defined as [1]_:

    .. math::

        \prox_{\tau\,\text{rMS}}(\mathbf{x})_i =
        \begin{cases}
        \frac{x_i}{1 + 2\tau\sigma}, & |x_i| \leq
        \sqrt{\frac{\kappa}{\sigma}(1 + 2\tau\sigma)} \\
        x_i, & \text{otherwise}
        \end{cases}

    that is, a Tikhonov-like shrinkage below the threshold and the identity
    above it, where the threshold is the point at which the cost of the
    shrunk solution equals the constant penalty :math:`\kappa`.

    .. [1] Strekalovskiy, E., and Cremers, D., "Real-time minimization of
       the piecewise smooth Mumford-Shah functional", European Conference
       on Computer Vision, pp. 127-141, 2014.

    """
    def __init__(self, sigma=1., kappa=1., g=None):
        super().__init__(None, False)
        self.sigma = sigma
        self.kappa = kappa
        self.g = g
        self.gdual = 0 if g is None else g
        self.count = 0

    def __call__(self, x):
        sigma = _current_sigma(self.sigma, self.count)
        kappa = _current_kappa(self.kappa, self.count)
        return np.minimum(sigma * np.linalg.norm(x) ** 2, kappa)

    def _increment_count(func):
        """Increment counter
        """
        def wrapped(self, *args, **kwargs):
            self.count += 1
            return func(self, *args, **kwargs)
        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x, tau):
        sigma = _current_sigma(self.sigma, self.count)
        kappa = _current_kappa(self.kappa, self.count)
        x = np.where(np.abs(x) <= np.sqrt(kappa / sigma * (1 + 2 * tau * sigma)),
                     _l2(x, tau * sigma), x)
        return x

    @_check_tau
    def proxdual(self, x, tau):
        # Moreau decomposition: prox_{tau f*}(x) = x - tau * prox_{f/tau}(x/tau)
        x = x - tau * self.prox(x / tau, 1. / tau)
        return x
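
As a quick sanity check of the closed form above, the prox can be compared
against a brute-force minimization of tau * rMS(z) + (z - x)**2 / 2 on a fine
grid. A minimal sketch (the parameter values, grid bounds and tolerance below
are arbitrary choices):

import numpy as np
from pyproximal.proximal.rMS import rMS

sigma, kappa, tau = 2., 1., 0.5
op = rMS(sigma=sigma, kappa=kappa)

x = np.linspace(-3., 3., 7)
xp = op.prox(x, tau)

# brute-force prox: argmin_z tau * min(sigma * z**2, kappa) + 0.5 * (z - x)**2
z = np.linspace(-10., 10., 200001)
for xi, xpi in zip(x, xp):
    cost = tau * np.minimum(sigma * z ** 2, kappa) + 0.5 * (z - xi) ** 2
    assert abs(z[np.argmin(cost)] - xpi) < 1e-3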

tutorials/relaxed_mumford-shah.py

Lines changed: 238 additions & 0 deletions
@@ -0,0 +1,238 @@
r"""
Relaxed Mumford-Shah regularization
===================================
In this tutorial we will use the relaxed Mumford-Shah (rMS) functional [1]_ as
regularization, which has the following form:

.. math::
    \text{rMS}(x) = \min (\alpha\Vert x\Vert_2^2, \kappa).

Its proximal operator is given by

.. math::
    \text{prox}_{\text{rMS}}(x) =
    \begin{cases}
    \frac{1}{1+2\alpha}x & \text{if } \vert x\vert \leq \sqrt{\frac{\kappa}{\alpha}(1 + 2\alpha)} \\
    x & \text{otherwise}
    \end{cases}

where the threshold is the point at which the cost of the shrunk solution
equals the constant penalty :math:`\kappa`.

rMS combines Tikhonov and TV regularization: below the threshold the penalty is
quadratic and the solution stays smooth, while above it the penalty saturates
at the constant :math:`\kappa`, so the solution is allowed to jump.
We show three denoising examples: one that is well-suited for TV regularization
and two where rMS outperforms both TV and Tikhonov regularization, modeled
after the experiments in [2]_.

**References**

.. [1] Strekalovskiy, E., and Cremers, D., "Real-time minimization of the
    piecewise smooth Mumford-Shah functional", European Conference on Computer
    Vision, pp. 127-141, 2014.
.. [2] Kadu, A., Kumar, R., and van Leeuwen, T., "Full-waveform inversion with
    Mumford-Shah regularization", SEG International Exposition and Annual
    Meeting, SEG-2018-2997224, 2018.

"""
import numpy as np
import matplotlib.pyplot as plt
import pylops

from pylops import FirstDerivative
from pylops.optimization.leastsquares import regularized_inversion
from pyproximal import L1, L2
from pyproximal.proximal.rMS import rMS
from pyproximal.optimization.primal import LinearizedADMM

np.random.seed(1)
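
###############################################################################
# Before turning to the denoising examples, it is instructive to look at the
# rMS proximal operator itself (an illustrative sketch; the choices of
# ``sigma``, ``kappa`` and ``tau`` below are arbitrary): inside the threshold
# it shrinks its input like Tikhonov regularization, outside it leaves the
# input untouched.

xgrid = np.linspace(-3, 3, 301)
prox_demo = rMS(sigma=1., kappa=1.)

fig, ax = plt.subplots(1, 1, figsize=(6, 5))
ax.plot(xgrid, xgrid, 'k--', label='Identity')
ax.plot(xgrid, prox_demo.prox(xgrid, 1.), label='rMS prox')
ax.set_title('rMS proximal operator')
ax.legend()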

###############################################################################
# We start with a simple model with two jumps that is well-suited for TV
# regularization

# Create noisy data
nx = 101
idx_jump1 = nx // 3
idx_jump2 = 3 * nx // 4
x = np.zeros(nx)
x[:idx_jump1] = 2
x[idx_jump1:idx_jump2] = 5
n = np.random.normal(0, 0.5, nx)
y = x + n

# Plot the model and the noisy data
fig, axs = plt.subplots(1, 1, figsize=(6, 5))
axs.plot(x, label='True model')
axs.plot(y, label='Noisy model')
axs.legend()

###############################################################################
# For rMS and TV we use Linearized ADMM, while for Tikhonov we use LSQR via
# ``regularized_inversion``

# Define functionals
l2 = L2(b=y)
l1 = L1(sigma=5.)
Dop = FirstDerivative(nx, edge=True, kind='backward')

# TV
L = np.real((Dop.H * Dop).eigs(neigs=1, which='LM')[0])
tau = 1.
mu = 0.99 * tau / L
xTV, _ = LinearizedADMM(l2, l1, Dop, tau=tau, mu=mu,
                        x0=np.zeros_like(x), niter=200)

# rMS
sigma = 1e5
kappa = 1e0
ms_relaxed = rMS(sigma=sigma, kappa=kappa)

# Solve
tau = 1.
mu = 1. / (tau * L)

xrMS, _ = LinearizedADMM(l2, ms_relaxed, Dop, tau=tau, mu=mu,
                         x0=np.zeros_like(x), niter=200)

# Tikhonov
Op = pylops.Identity(nx)
Regs = [Dop]
epsR = [6e0]

xTikhonov = regularized_inversion(Op=Op, Regs=Regs, y=y, epsRs=epsR)[0]

# Plot the results
fig, axs = plt.subplots(1, 1, figsize=(6, 5))
axs.plot(x, label='True', linewidth=4, color='k')
axs.plot(y, '--', label='Noisy', linewidth=2, color='y')
axs.plot(xTV, label='TV')
axs.plot(xrMS, label='rMS')
axs.plot(xTikhonov, label='Tikhonov')
axs.legend()
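###############################################################################
# A simple way to compare the three reconstructions quantitatively is the
# root-mean-square error with respect to the true model (a minimal sketch)

for label, xhat in zip(['TV', 'rMS', 'Tikhonov'], [xTV, xrMS, xTikhonov]):
    print(f'{label} RMSE: {np.sqrt(np.mean((xhat - x) ** 2)):.4f}')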

###############################################################################
# Next, we consider an example where we replace the first jump with a slope.
# As we will see, TV cannot handle this type of structure: its L1 penalty on
# the gradient favors piecewise-constant solutions, so it turns the slope into
# a staircase. rMS, on the other hand, can reconstruct the model with high
# accuracy.

nx = 101
idx_jump1 = nx // 3
idx_jump2 = 3 * nx // 4
x = np.zeros(nx)
x[:idx_jump1] = 2
x[idx_jump1:idx_jump2] = np.linspace(2, 4, idx_jump2 - idx_jump1)
n = np.random.normal(0, 0.25, nx)
y = x + n

# Plot the model and the noisy data
fig, axs = plt.subplots(1, 1, figsize=(6, 5))
axs.plot(x, label='True model')
axs.plot(y, label='Noisy model')
axs.legend()

###############################################################################

# Define functionals
l2 = L2(b=y)
l1 = L1(sigma=1.)
Dop = FirstDerivative(nx, edge=True, kind='backward')

# TV
L = np.real((Dop.H * Dop).eigs(neigs=1, which='LM')[0])
tau = 1.
mu = 0.99 * tau / L
xTV, _ = LinearizedADMM(l2, l1, Dop, tau=tau, mu=mu,
                        x0=np.zeros_like(x), niter=200)

# rMS
sigma = 1e1
kappa = 1e0
ms_relaxed = rMS(sigma=sigma, kappa=kappa)

# Solve
tau = 1.
mu = 1. / (tau * L)

xrMS, _ = LinearizedADMM(l2, ms_relaxed, Dop, tau=tau, mu=mu,
                         x0=np.zeros_like(x), niter=200)

# Tikhonov
Op = pylops.Identity(nx)
Regs = [Dop]
epsR = [3e0]

xTikhonov = regularized_inversion(Op=Op, Regs=Regs, y=y, epsRs=epsR)[0]

# Plot the results
fig, axs = plt.subplots(1, 1, figsize=(6, 5))
axs.plot(x, label='True', linewidth=4, color='k')
axs.plot(y, '--', label='Noisy', linewidth=2, color='y')
axs.plot(xTV, label='TV')
axs.plot(xrMS, label='rMS')
axs.plot(xTikhonov, label='Tikhonov')
axs.legend()
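###############################################################################
# ``sigma`` and ``kappa`` can also be passed as functions of a counter that
# tracks how many times ``prox`` has been called, which enables simple
# continuation strategies. A sketch (the schedule below is an arbitrary
# choice): start with a small quadratic weight and increase it as the
# iterations progress.

sigma_fun = lambda count: min(1e1, 1e-1 * 2. ** (count // 20))
ms_cont = rMS(sigma=sigma_fun, kappa=1e0)
xrMS_cont, _ = LinearizedADMM(l2, ms_cont, Dop, tau=1., mu=1. / L,
                              x0=np.zeros_like(x), niter=200)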
###############################################################################
# Finally, we take a single trace from a section of the Marmousi model. The
# trace is rather smooth with a few jumps, which makes it perfectly suited for
# rMS; TV, on the other hand, artificially creates a staircasing effect.

# Get a trace from the model and add some noise
m_trace = np.load('../testdata/marmousi_trace.npy')
nz = len(m_trace)
m_trace_noisy = m_trace + np.random.normal(0, 0.1, nz)

# Trace of the Marmousi model
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
ax.plot(m_trace, linewidth=2, label='True')
ax.plot(m_trace_noisy, label='Noisy')
ax.set_title('Trace and noisy trace')
ax.axis('tight')
ax.legend()
fig.tight_layout()

###############################################################################

# Define functionals
l2 = L2(b=m_trace_noisy)
l1 = L1(sigma=5e-1)
Dop = FirstDerivative(nz, edge=True, kind='backward')

# TV
L = np.real((Dop.H * Dop).eigs(neigs=1, which='LM')[0])
tau = 1.
mu = 0.99 * tau / L
xTV, _ = LinearizedADMM(l2, l1, Dop, tau=tau, mu=mu,
                        x0=np.zeros_like(m_trace), niter=200)

# rMS
sigma = 5e0
kappa = 1e-1
ms_relaxed = rMS(sigma=sigma, kappa=kappa)

# Solve
tau = 1.
mu = 1. / (tau * L)

xrMS, _ = LinearizedADMM(l2, ms_relaxed, Dop, tau=tau, mu=mu,
                         x0=np.zeros_like(m_trace), niter=200)

# Tikhonov
Op = pylops.Identity(nz)
Regs = [Dop]
epsR = [3e0]

xTikhonov = regularized_inversion(Op=Op, Regs=Regs, y=m_trace_noisy, epsRs=epsR)[0]

# Plot the results
fig, axs = plt.subplots(1, 1, figsize=(6, 5))
axs.plot(m_trace, label='True', linewidth=4, color='k')
axs.plot(m_trace_noisy, '--', label='Noisy', linewidth=2, color='y')
axs.plot(xTV, label='TV')
axs.plot(xrMS, label='rMS')
axs.plot(xTikhonov, label='Tikhonov')
axs.legend()
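
###############################################################################
# The rMS penalty implicitly segments the trace: wherever
# ``sigma * (Dop @ xrMS) ** 2`` reaches ``kappa``, the constant penalty is
# active and the reconstruction is allowed to jump. A small sketch extracting
# these jump locations from the rMS solution:

grad = Dop @ xrMS
jumps = np.where(sigma * grad ** 2 >= kappa)[0]
print('Detected jump locations:', jumps)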
