Skip to content

Commit 7e26fe4

Browse files
author
Olivier Leblanc
committed
add TV operator
1 parent 4a13fff commit 7e26fe4

File tree

1 file changed

+296
-0
lines changed

1 file changed

+296
-0
lines changed

pyproximal/proximal/TV.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
import numpy as np
2+
3+
from copy import deepcopy
4+
from scipy.sparse.linalg import lsqr
5+
from pylops import FirstDerivative, Gradient
6+
from pyproximal.ProxOperator import _check_tau
7+
from pyproximal import ProxOperator
8+
9+
10+
class TV(ProxOperator):
11+
r"""TV Norm proximal operator.
12+
13+
Proximal operator for the TV norm defined as: :math:`f(\mathbf{x}) =
14+
\sigma ||\mathbf{x}||_{\text{TV}}`.
15+
16+
Parameters
17+
----------
18+
dim : :obj:`int`
19+
Dimension of the object.
20+
sigma : :obj:`int`, optional
21+
Multiplicative coefficient of TV norm
22+
niter : :obj:`int` or :obj:`func`, optional
23+
Number of iterations of iterative scheme used to compute the proximal.
24+
This can be a constant number or a function that is called passing a
25+
counter which keeps track of how many times the ``prox`` method has
26+
been invoked before and returns the ``niter`` to be used.
27+
x0 : :obj:`np.ndarray`, optional
28+
Initial vector
29+
warm : :obj:`bool`, optional
30+
Warm start (``True``) or not (``False``). Uses estimate from previous
31+
call of ``prox`` method.
32+
rtol : :obj:`float`, optional
33+
Relative tolerance for stopping criterion.
34+
35+
Notes
36+
-----
37+
The proximal algorithm is implemented following [1].
38+
39+
.. [1] Beck, A. and Teboulle, M., "Fast gradient-based algorithms for constrained total variation image denoising and deblurring problems", 2009.
40+
"""
41+
def __init__(self, dim=2, Op=None, sigma=1.,
42+
niter=10, x0=None, warm=True, rtol=1e-4, **kwargs):
43+
super().__init__(Op, True)
44+
self.dim = dim
45+
self.Op = Op
46+
self.sigma = sigma
47+
self.niter = niter
48+
self.x0 = x0
49+
self.warm = warm
50+
self.count = 0
51+
self.rtol = rtol
52+
self.kwargs = kwargs
53+
54+
def __call__(self, x):
55+
if self.dim == 1:
56+
if (x.ndim == 1):
57+
N = len(x)
58+
else:
59+
N = x.shape[0]
60+
derivOp = FirstDerivative(N, dims=None, dir=0, edge=False,
61+
dtype="float64", kind="forward")
62+
dx = derivOp @ x
63+
y = np.sum(np.abs(dx), axis=0)
64+
65+
elif self.dim >= 2:
66+
y = 0
67+
grads = []
68+
gradOp = Gradient(x.shape, edge=False, dtype="float64", kind="forward")
69+
grads = gradOp.matvec(x.ravel())
70+
grads = grads.reshape((self.dim,)+x.shape)
71+
for g in grads:
72+
y += np.power(abs(g), 2)
73+
y = np.sqrt(y)
74+
75+
return self.sigma * np.sum(y)
76+
77+
def _increment_count(func):
78+
"""Increment counter
79+
"""
80+
def wrapped(self, *args, **kwargs):
81+
self.count += 1
82+
return func(self, *args, **kwargs)
83+
return wrapped
84+
85+
@_increment_count
86+
@_check_tau
87+
def prox(self, x, tau):
88+
# define current number of iterations
89+
if isinstance(self.niter, int):
90+
niter = self.niter
91+
else:
92+
niter = self.niter(self.count)
93+
94+
gamma = self.sigma * tau
95+
rtol = self.rtol
96+
97+
# TODO implement test_gamma
98+
# Initialization
99+
sol = x
100+
101+
if self.dim == 1:
102+
if (x.ndim == 1):
103+
N = len(x)
104+
else:
105+
N = x.shape[0]
106+
derivOp = FirstDerivative(N, dims=None, dir=0, edge=False,
107+
dtype="float64", kind="forward")
108+
else:
109+
gradOp = Gradient(x.shape, edge=False, dtype="float64", kind="forward")
110+
111+
if self.dim == 1:
112+
r = derivOp @ (x*0)
113+
# r = op.grad(x * 0, dim=self.dim, **self.kwargs)
114+
rr = deepcopy(r)
115+
elif self.dim == 2:
116+
r, s = gradOp.matvec( (x*0).ravel()).reshape((self.dim,)+x.shape)
117+
rr, ss = deepcopy(r), deepcopy(s)
118+
elif self.dim == 3:
119+
r, s, k = gradOp.matvec( (x*0).ravel()).reshape((self.dim,)+x.shape)
120+
rr, ss, kk = deepcopy(r), deepcopy(s), deepcopy(k)
121+
elif self.dim == 4:
122+
r, s, k, u = gradOp.matvec( (x*0).ravel()).reshape((self.dim,)+x.shape)
123+
rr, ss, kk, uu = deepcopy(r), deepcopy(s), deepcopy(k), deepcopy(u)
124+
125+
if self.dim >= 1:
126+
pold = r
127+
if self.dim >= 2:
128+
qold = s
129+
if self.dim >= 3:
130+
kold = k
131+
if self.dim >= 4:
132+
uold = u
133+
134+
told, prev_obj = 1., 0.
135+
136+
# Initialization for weights
137+
if self.dim >= 1:
138+
try:
139+
wx = self.kwargs["wx"]
140+
except (KeyError, TypeError):
141+
wx = 1.
142+
if self.dim >= 2:
143+
try:
144+
wy = self.kwargs["wy"]
145+
except (KeyError, TypeError):
146+
wy = 1.
147+
if self.dim >= 3:
148+
try:
149+
wz = self.kwargs["wz"]
150+
except (KeyError, TypeError):
151+
wz = 1.
152+
if self.dim >= 4:
153+
try:
154+
wt = self.kwargs["wt"]
155+
except (KeyError, TypeError):
156+
wt = 1.
157+
158+
if self.dim == 1:
159+
mt = wx
160+
elif self.dim == 2:
161+
mt = np.maximum(wx, wy)
162+
elif self.dim == 3:
163+
mt = np.maximum(wx, np.maximum(wy, wz))
164+
elif self.dim == 4:
165+
mt = np.maximum(np.maximum(wx, wy), np.maximum(wz, wt))
166+
167+
168+
if self.dim >= 1:
169+
try:
170+
rr *= np.conjugate(wx)
171+
except KeyError:
172+
pass
173+
if self.dim >= 2:
174+
try:
175+
ss *= np.conjugate(wy)
176+
except KeyError:
177+
pass
178+
if self.dim >= 3:
179+
try:
180+
kk *= np.conjugate(wz)
181+
except KeyError:
182+
pass
183+
if self.dim >= 4:
184+
try:
185+
uu *= np.conjugate(wt)
186+
except KeyError:
187+
pass
188+
189+
iter = 0
190+
while iter <= niter:
191+
# Current Solution
192+
if self.dim == 0:
193+
raise ValueError("Need to input at least one value")
194+
195+
if self.dim >= 1:
196+
div = np.concatenate((np.expand_dims(rr[0, ], axis=0),
197+
rr[1:-1, ] - rr[:-2, ],
198+
-np.expand_dims(rr[-2, ], axis=0)),
199+
axis=0)
200+
201+
if self.dim >= 2:
202+
div += np.concatenate((np.expand_dims(ss[:, 0, ], axis=1),
203+
ss[:, 1:-1, ] - ss[:, :-2, ],
204+
-np.expand_dims(ss[:, -2, ], axis=1)),
205+
axis=1)
206+
207+
if self.dim >= 3:
208+
div += np.concatenate((np.expand_dims(kk[:, :, 0, ], axis=2),
209+
kk[:, :, 1:-1, ] - kk[:, :, :-2, ],
210+
-np.expand_dims(kk[:, :, -2, ], axis=2)),
211+
axis=2)
212+
213+
if self.dim >= 4:
214+
div += np.concatenate((np.expand_dims(uu[:, :, :, 0, ], axis=3),
215+
uu[:, :, :, 1:-1, ] - uu[:, :, :, :-2, ],
216+
-np.expand_dims(uu[:, :, :, -2, ], axis=3)),
217+
axis=3)
218+
sol = x - gamma * div
219+
220+
# Objective function value
221+
obj = 0.5 * np.power(np.linalg.norm(x[:] - sol[:]), 2) + \
222+
gamma * np.sum(self.__call__(sol), axis=0)
223+
if (obj > 1e-10):
224+
rel_obj = np.abs(obj - prev_obj) / obj
225+
else:
226+
rel_obj = 2*rtol
227+
prev_obj = obj
228+
229+
# Stopping criterion
230+
if rel_obj < rtol:
231+
break
232+
233+
# Update divergence vectors and project
234+
if self.dim == 1:
235+
dx = derivOp(sol)
236+
r -= 1. / (4 * gamma * mt**2) * dx
237+
weights = np.maximum(1, np.abs(r))
238+
239+
elif self.dim == 2:
240+
dx, dy = gradOp.matvec( sol.ravel()).reshape((self.dim,)+x.shape)
241+
r -= (1. / (8. * gamma * mt**2.)) * dx
242+
s -= (1. / (8. * gamma * mt**2.)) * dy
243+
weights = np.maximum(1, np.sqrt(np.power(np.abs(r), 2) +
244+
np.power(np.abs(s), 2)))
245+
246+
elif self.dim == 3:
247+
dx, dy, dz = gradOp.matvec( sol.ravel()).reshape((self.dim,)+x.shape)
248+
r -= 1. / (12. * gamma * mt**2) * dx
249+
s -= 1. / (12. * gamma * mt**2) * dy
250+
k -= 1. / (12. * gamma * mt**2) * dz
251+
weights = np.maximum(1, np.sqrt(np.power(np.abs(r), 2) +
252+
np.power(np.abs(s), 2) +
253+
np.power(np.abs(k), 2)))
254+
255+
elif self.dim == 4:
256+
dx, dy, dz, dt = gradOp.matvec( sol.ravel()).reshape((self.dim,)+x.shape)
257+
r -= 1. / (16 * gamma * mt**2) * dx
258+
s -= 1. / (16 * gamma * mt**2) * dy
259+
k -= 1. / (16 * gamma * mt**2) * dz
260+
u -= 1. / (16 * gamma * mt**2) * dt
261+
weights = np.maximum(1, np.sqrt(np.power(np.abs(r), 2) +
262+
np.power(np.abs(s), 2) +
263+
np.power(np.abs(k), 2) +
264+
np.power(np.abs(u), 2)))
265+
266+
# FISTA update
267+
t = (1 + np.sqrt(4 * told**2)) / 2.
268+
269+
if self.dim >= 1:
270+
p = r / weights
271+
r = p + (told - 1) / t * (p - pold)
272+
pold = p
273+
rr = deepcopy(r)
274+
275+
if self.dim >= 2:
276+
q = s / weights
277+
s = q + (told - 1) / t * (q - qold)
278+
ss = deepcopy(s)
279+
qold = q
280+
281+
if self.dim >= 3:
282+
o = k / weights
283+
k = o + (told - 1) / t * (o - kold)
284+
kk = deepcopy(k)
285+
kold = o
286+
287+
if self.dim >= 4:
288+
m = u / weights
289+
u = m + (told - 1) / t * (m - uold)
290+
uu = deepcopy(u)
291+
uold = m
292+
293+
told = t
294+
iter += 1
295+
296+
return sol

0 commit comments

Comments
 (0)