11import time
22import numpy as np
33
4+ from pylops .utils .backend import get_array_module
45
5- def PrimalDual (proxf , proxg , A , x0 , tau , mu , z = None , theta = 1. , niter = 10 ,
6- gfirst = True , callback = None , callbacky = False , show = False ):
6+
7+ def PrimalDual (proxf , proxg , A , x0 , tau , mu , y0 = None , z = None , theta = 1. , niter = 10 ,
8+ gfirst = True , callback = None , callbacky = False , returny = False , show = False ):
79 r"""Primal-dual algorithm
810
911 Solves the following (possibly) nonlinear minimization problem using
@@ -45,6 +47,8 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
4547 mu : :obj:`float` or :obj:`np.ndarray`
4648 Stepsize of subgradient of :math:`g^*`. This can be constant
4749 or function of iterations (in the latter cases provided as np.ndarray)
50+ z0 : :obj:`numpy.ndarray`
51+ Initial auxiliary vector
4852 z : :obj:`numpy.ndarray`, optional
4953 Additional vector
5054 theta : :obj:`float`
@@ -62,6 +66,8 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
6266 where ``x`` is the current model vector
6367 callbacky : :obj:`bool`, optional
6468 Modify callback signature to (``callback(x, y)``) when ``callbacky=True``
69+ returny : :obj:`bool`, optional
70+ Return also ``y``
6571 show : :obj:`bool`, optional
6672 Display iterations log
6773
@@ -102,13 +108,15 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
102108 Imaging and Vision, 40, 8pp. 120-145. 2011.
103109
104110 """
111+ ncp = get_array_module (x0 )
112+
105113 # check if tau and mu are scalars or arrays
106114 fixedtau = fixedmu = False
107115 if isinstance (tau , (int , float )):
108- tau = tau * np .ones (niter )
116+ tau = tau * ncp .ones (niter )
109117 fixedtau = True
110118 if isinstance (mu , (int , float )):
111- mu = mu * np .ones (niter )
119+ mu = mu * ncp .ones (niter )
112120 fixedmu = True
113121
114122 if show :
@@ -128,8 +136,7 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
128136
129137 x = x0 .copy ()
130138 xhat = x .copy ()
131- y = np .zeros (A .shape [0 ], dtype = x .dtype )
132-
139+ y = y0 .copy () if y0 is not None else ncp .zeros (A .shape [0 ], dtype = x .dtype )
133140 for iiter in range (niter ):
134141 xold = x .copy ()
135142 if gfirst :
@@ -165,7 +172,10 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, z=None, theta=1., niter=10,
165172 if show :
166173 print ('\n Total time (s) = %.2f' % (time .time () - tstart ))
167174 print ('---------------------------------------------------------\n ' )
168- return x
175+ if not returny :
176+ return x
177+ else :
178+ return x , y
169179
170180
171181def AdaptivePrimalDual (proxf , proxg , A , x0 , tau , mu ,
0 commit comments