Skip to content

Commit 5d378e2

Browse files
committed
feature: modify PnP signature to allow any proximal solver
1 parent b98160e commit 5d378e2

File tree

3 files changed

+55
-26
lines changed

3 files changed

+55
-26
lines changed

pyproximal/optimization/pnp.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,19 @@ def prox(self, x, tau):
3333
return xden.ravel()
3434

3535

36-
def PlugAndPlay(proxf, denoiser, dims, x0, tau, niter=10,
37-
callback=None, show=False):
38-
r"""Plug-and-Play Priors with ADMM optimization
36+
def PlugAndPlay(proxf, denoiser, dims, x0, solver=ADMM, **kwargs_solver):
37+
r"""Plug-and-Play Priors with any proximal algorithm of choice
3938
40-
Solves the following minimization problem using the ADMM algorithm:
39+
Solves the following minimization problem using any proximal a
40+
lgorithm of choice:
4141
4242
.. math::
4343
4444
\mathbf{x},\mathbf{z} = \argmin_{\mathbf{x}}
4545
f(\mathbf{x}) + \lambda g(\mathbf{x})
4646
47-
where :math:`f(\mathbf{x})` is a function that has a known proximal
48-
operator where :math:`g(\mathbf{x})` is a function acting as implicit
47+
where :math:`f(\mathbf{x})` is a function that has a known gradient or
48+
proximal operator and :math:`g(\mathbf{x})` is a function acting as implicit
4949
prior. Implicit means that no explicit function should be defined: instead,
5050
a denoising algorithm of choice is used. See Notes for details.
5151
@@ -62,6 +62,8 @@ def PlugAndPlay(proxf, denoiser, dims, x0, tau, niter=10,
6262
prior to calling the ``denoiser``
6363
x0 : :obj:`numpy.ndarray`
6464
Initial vector
65+
solver : :func:`pyproximal.optimization.primal` or :func:`pyproximal.optimization.primaldual`
66+
Solver of choice
6567
tau : :obj:`float`, optional
6668
Positive scalar weight, which should satisfy the following condition
6769
to guarantees convergence: :math:`\tau \in (0, 1/L]` where ``L`` is
@@ -83,7 +85,8 @@ def PlugAndPlay(proxf, denoiser, dims, x0, tau, niter=10,
8385
8486
Notes
8587
-----
86-
Plug-and-Play Priors [1]_ can be expressed by the following recursion:
88+
Plug-and-Play Priors [1]_ can be used with any proximal algorithm of choice. For example, when
89+
ADMM is selected, the resulting scheme can be expressed by the following recursion:
8790
8891
.. math::
8992
@@ -119,6 +122,4 @@ def PlugAndPlay(proxf, denoiser, dims, x0, tau, niter=10,
119122
# Denoiser
120123
proxpnp = _Denoise(denoiser, dims=dims)
121124

122-
return ADMM(proxf, proxpnp, tau=tau, x0=x0,
123-
niter=niter, callback=callback,
124-
show=show)
125+
return solver(proxf, proxpnp, x0=x0, **kwargs_solver)

pyproximal/optimization/primal.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,11 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
213213
'Proximal operator (f): %s\n'
214214
'Proximal operator (g): %s\n'
215215
'tau = %s\tbeta=%10e\n'
216-
'epsg = %s\tniter = %d\t'
217-
'niterback = %d\n' % (type(proxf), type(proxg),
216+
'epsg = %s\tniter = %d\n'
217+
''
218+
'niterback = %d\tacceleration = %s\n' % (type(proxf), type(proxg),
218219
'Adaptive' if tau is None else str(tau), beta,
219-
epsg_print, niter, niterback))
220+
epsg_print, niter, niterback, acceleration))
220221
head = ' Itn x[0] f g J=f+eps*g'
221222
print(head)
222223

tutorials/plugandplay.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,15 @@
1111
1212
As an example, we will consider a simplified MRI experiment, where the
1313
data is created by appling a 2D Fourier Transform to the input model and
14-
by randomly sampling 60% of its values. We will also use the famous
14+
by randomly sampling 60% of its values. We will use the famous
1515
`BM3D <https://pypi.org/project/bm3d>`_ as the denoiser, but any other denoiser
1616
of choice can be used instead!
1717
18+
Finally, whilst in the original paper, PnP is associated to the ADMM solver, subsequent
19+
research showed that the same principle can be applied to pretty much any proximal
20+
solver. We will show how to pass a solver of choice to our
21+
:func:`pyproximal.optimization.pnp.PlugAndPlay` solver.
22+
1823
"""
1924
import numpy as np
2025
import matplotlib.pyplot as plt
@@ -67,7 +72,7 @@
6772

6873
###############################################################################
6974
# At this point we create a denoiser instance using the BM3D algorithm and use
70-
# as Plug-and-Play Prior to the ADMM algorithm
75+
# as Plug-and-Play Prior to the PG and ADMM algorithms
7176

7277
def callback(x, xtrue, errhist):
7378
errhist.append(np.linalg.norm(x - xtrue))
@@ -83,24 +88,46 @@ def callback(x, xtrue, errhist):
8388
denoiser = lambda x, tau: bm3d.bm3d(np.real(x), sigma_psd=sigma * tau,
8489
stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)
8590

86-
errhist = []
87-
xpnp = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape,
88-
tau=tau, x0=np.zeros(x.size),
89-
niter=40, show=True,
90-
callback=lambda xx: callback(xx, x.ravel(),
91-
errhist))[0]
92-
xpnp = np.real(xpnp.reshape(x.shape))
91+
# PG-Pnp
92+
errhistpg = []
93+
xpnppg = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape,
94+
solver=pyproximal.optimization.primal.ProximalGradient,
95+
tau=tau, x0=np.zeros(x.size),
96+
niter=40,
97+
acceleration='fista',
98+
show=True,
99+
callback=lambda xx: callback(xx, x.ravel(),
100+
errhistpg))
101+
xpnppg = np.real(xpnppg.reshape(x.shape))
102+
103+
# ADMM-PnP
104+
errhistadmm = []
105+
xpnpadmm = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape,
106+
solver=pyproximal.optimization.primal.ADMM,
107+
tau=tau, x0=np.zeros(x.size),
108+
niter=40, show=True,
109+
callback=lambda xx: callback(xx, x.ravel(),
110+
errhistadmm))[0]
111+
xpnpadmm = np.real(xpnpadmm.reshape(x.shape))
93112

94-
fig, axs = plt.subplots(1, 2, figsize=(9, 5))
113+
fig, axs = plt.subplots(1, 3, figsize=(14, 5))
95114
axs[0].imshow(x, vmin=0, vmax=1, cmap="gray")
96115
axs[0].set_title("Model")
97116
axs[0].axis("tight")
98-
axs[1].imshow(xpnp, vmin=0, vmax=1, cmap="gray")
99-
axs[1].set_title("PnP Inversion")
117+
axs[1].imshow(xpnppg, vmin=0, vmax=1, cmap="gray")
118+
axs[1].set_title("PG-PnP Inversion")
100119
axs[1].axis("tight")
120+
axs[2].imshow(xpnpadmm, vmin=0, vmax=1, cmap="gray")
121+
axs[2].set_title("ADMM-PnP Inversion")
122+
axs[2].axis("tight")
101123
plt.tight_layout()
102124

125+
###############################################################################
126+
# Finally, let's compare the error convergence of the two variations of PnP
127+
103128
plt.figure(figsize=(12, 3))
104-
plt.plot(errhist, 'k', lw=2)
129+
plt.plot(errhistpg, 'k', lw=2, label='PG')
130+
plt.plot(errhistadmm, 'r', lw=2, label='ADMM')
105131
plt.title("Error norm")
132+
plt.legend()
106133
plt.tight_layout()

0 commit comments

Comments
 (0)