1111
1212As an example, we will consider a simplified MRI experiment, where the
1313data is created by appling a 2D Fourier Transform to the input model and
14- by randomly sampling 60% of its values. We will also use the famous
14+ by randomly sampling 60% of its values. We will use the famous
1515`BM3D <https://pypi.org/project/bm3d>`_ as the denoiser, but any other denoiser
1616of choice can be used instead!
1717
18+ Finally, whilst in the original paper, PnP is associated to the ADMM solver, subsequent
19+ research showed that the same principle can be applied to pretty much any proximal
20+ solver. We will show how to pass a solver of choice to our
21+ :func:`pyproximal.optimization.pnp.PlugAndPlay` solver.
22+
1823"""
1924import numpy as np
2025import matplotlib .pyplot as plt
6772
6873###############################################################################
6974# At this point we create a denoiser instance using the BM3D algorithm and use
70- # as Plug-and-Play Prior to the ADMM algorithm
75+ # as Plug-and-Play Prior to the PG and ADMM algorithms
7176
7277def callback (x , xtrue , errhist ):
7378 errhist .append (np .linalg .norm (x - xtrue ))
@@ -83,24 +88,46 @@ def callback(x, xtrue, errhist):
8388denoiser = lambda x , tau : bm3d .bm3d (np .real (x ), sigma_psd = sigma * tau ,
8489 stage_arg = bm3d .BM3DStages .HARD_THRESHOLDING )
8590
86- errhist = []
87- xpnp = pyproximal .optimization .pnp .PlugAndPlay (l2 , denoiser , x .shape ,
88- tau = tau , x0 = np .zeros (x .size ),
89- niter = 40 , show = True ,
90- callback = lambda xx : callback (xx , x .ravel (),
91- errhist ))[0 ]
92- xpnp = np .real (xpnp .reshape (x .shape ))
91+ # PG-Pnp
92+ errhistpg = []
93+ xpnppg = pyproximal .optimization .pnp .PlugAndPlay (l2 , denoiser , x .shape ,
94+ solver = pyproximal .optimization .primal .ProximalGradient ,
95+ tau = tau , x0 = np .zeros (x .size ),
96+ niter = 40 ,
97+ acceleration = 'fista' ,
98+ show = True ,
99+ callback = lambda xx : callback (xx , x .ravel (),
100+ errhistpg ))
101+ xpnppg = np .real (xpnppg .reshape (x .shape ))
102+
103+ # ADMM-PnP
104+ errhistadmm = []
105+ xpnpadmm = pyproximal .optimization .pnp .PlugAndPlay (l2 , denoiser , x .shape ,
106+ solver = pyproximal .optimization .primal .ADMM ,
107+ tau = tau , x0 = np .zeros (x .size ),
108+ niter = 40 , show = True ,
109+ callback = lambda xx : callback (xx , x .ravel (),
110+ errhistadmm ))[0 ]
111+ xpnpadmm = np .real (xpnpadmm .reshape (x .shape ))
93112
94- fig , axs = plt .subplots (1 , 2 , figsize = (9 , 5 ))
113+ fig , axs = plt .subplots (1 , 3 , figsize = (14 , 5 ))
95114axs [0 ].imshow (x , vmin = 0 , vmax = 1 , cmap = "gray" )
96115axs [0 ].set_title ("Model" )
97116axs [0 ].axis ("tight" )
98- axs [1 ].imshow (xpnp , vmin = 0 , vmax = 1 , cmap = "gray" )
99- axs [1 ].set_title ("PnP Inversion" )
117+ axs [1 ].imshow (xpnppg , vmin = 0 , vmax = 1 , cmap = "gray" )
118+ axs [1 ].set_title ("PG- PnP Inversion" )
100119axs [1 ].axis ("tight" )
120+ axs [2 ].imshow (xpnpadmm , vmin = 0 , vmax = 1 , cmap = "gray" )
121+ axs [2 ].set_title ("ADMM-PnP Inversion" )
122+ axs [2 ].axis ("tight" )
101123plt .tight_layout ()
102124
125+ ###############################################################################
126+ # Finally, let's compare the error convergence of the two variations of PnP
127+
103128plt .figure (figsize = (12 , 3 ))
104- plt .plot (errhist , 'k' , lw = 2 )
129+ plt .plot (errhistpg , 'k' , lw = 2 , label = 'PG' )
130+ plt .plot (errhistadmm , 'r' , lw = 2 , label = 'ADMM' )
105131plt .title ("Error norm" )
132+ plt .legend ()
106133plt .tight_layout ()
0 commit comments