Skip to content

Commit 4959e76

Browse files
committed
bugfix: change prints of ProxGrad and AccProxGrad for multivalues epsg
1 parent 0a0aa52 commit 4959e76

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

pyproximal/optimization/primal.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
132132
when dealing with problems with multiple right-hand-sides
133133
beta : obj:`float`, optional
134134
Backtracking parameter (must be between 0 and 1)
135-
epsg : :obj:`float`, optional
135+
epsg : :obj:`float` or :obj:`np.ndarray`, optional
136136
Scaling factor of g function
137137
niter : :obj:`int`, optional
138138
Number of iterations of iterative scheme
@@ -176,17 +176,23 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
176176
1/(2\tau)||\mathbf{x} - \mathbf{y}||_2^2`.
177177
178178
"""
179+
# check if epgs is a ve
180+
if np.asarray(epsg).size == 1.:
181+
epsg_print = str(epsg)
182+
else:
183+
epsg_print = 'Multi'
184+
179185
if show:
180186
tstart = time.time()
181187
print('Proximal Gradient\n'
182188
'---------------------------------------------------------\n'
183189
'Proximal operator (f): %s\n'
184190
'Proximal operator (g): %s\n'
185191
'tau = %10e\tbeta=%10e\n'
186-
'epsg = %10e\tniter = %d\t'
192+
'epsg = %s\tniter = %d\t'
187193
'niterback = %d\n' % (type(proxf), type(proxg),
188-
0 if tau is None else tau, beta, epsg,
189-
niter, niterback))
194+
0 if tau is None else tau, beta,
195+
epsg_print, niter, niterback))
190196
head = ' Itn x[0] f g J=f+eps*g'
191197
print(head)
192198

@@ -211,7 +217,9 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
211217
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
212218
pf, pg = proxf(x), proxg(x)
213219
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
214-
(iiter + 1, x[0], pf, pg, pf + epsg * pg)
220+
(iiter + 1, x[0] if x.ndim == 1 else x[0, 0],
221+
pf, pg[0] if epsg_print == 'Multi' else pg,
222+
pf + np.sum(epsg * pg))
215223
print(msg)
216224
if show:
217225
print('\nTotal time (s) = %.2f' % (time.time() - tstart))
@@ -253,7 +261,7 @@ def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
253261
when dealing with problems with multiple right-hand-sides
254262
beta : obj:`float`, optional
255263
Backtracking parameter (must be between 0 and 1)
256-
epsg : :obj:`float`, optional
264+
epsg : :obj:`float` or :obj:`np.ndarray`, optional
257265
Scaling factor of g function
258266
niter : :obj:`int`, optional
259267
Number of iterations of iterative scheme
@@ -294,6 +302,12 @@ def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
294302
Imaging Sciences, vol. 2, pp. 183-202. 2009.
295303
296304
"""
305+
# check if epgs is a ve
306+
if np.asarray(epsg).size == 1.:
307+
epsg_print = str(epsg)
308+
else:
309+
epsg_print = 'Multi'
310+
297311
if acceleration not in ['vandenberghe', 'fista']:
298312
raise NotImplementedError('Acceleration should be vandenberghe '
299313
'or fista')
@@ -303,10 +317,10 @@ def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
303317
'---------------------------------------------------------\n'
304318
'Proximal operator (f): %s\n'
305319
'Proximal operator (g): %s\n'
306-
'tau = %10e\tepsg = %10e\tniter = %d\n' % (type(proxf),
307-
type(proxg),
308-
0 if tau is None else tau,
309-
epsg, niter))
320+
'tau = %10e\tepsg = %s\tniter = %d\n' % (type(proxf),
321+
type(proxg),
322+
0 if tau is None else tau,
323+
epsg_print, niter))
310324
head = ' Itn x[0] f g J=f+eps*g'
311325
print(head)
312326

@@ -348,7 +362,9 @@ def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
348362
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
349363
pf, pg = proxf(x), proxg(x)
350364
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
351-
(iiter + 1, x[0], pf, pg, pf + epsg * pg)
365+
(iiter + 1, x[0] if x.ndim == 1 else x[0, 0],
366+
pf, pg[0] if epsg_print == 'Multi' else pg,
367+
pf + np.sum(epsg * pg))
352368
print(msg)
353369
if show:
354370
print('\nTotal time (s) = %.2f' % (time.time() - tstart))

0 commit comments

Comments
 (0)