Skip to content

Commit 4f52b0b

Browse files
authored
Merge pull request #168 from mrava87/fix-printprimal
fix: change print of x[0] to work with cupy
2 parents 2d53f07 + c957605 commit 4f52b0b

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

pyproximal/optimization/primal.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,8 @@ def HQS(proxf, proxg, x0, tau, niter=10, z0=None, gfirst=True,
581581
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
582582
pf, pg = proxf(x), proxg(x)
583583
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
584-
(iiter + 1, x[0], pf, pg, pf + pg)
584+
(iiter + 1, np.real(to_numpy(x[0])),
585+
pf, pg, pf + pg)
585586
print(msg)
586587
if show:
587588
print('\nTotal time (s) = %.2f' % (time.time() - tstart))
@@ -706,7 +707,8 @@ def ADMM(proxf, proxg, x0, tau, niter=10, gfirst=False,
706707
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
707708
pf, pg = proxf(x), proxg(x)
708709
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
709-
(iiter + 1, x[0], pf, pg, pf + pg)
710+
(iiter + 1, np.real(to_numpy(x[0])),
711+
pf, pg, pf + pg)
710712
print(msg)
711713
if show:
712714
print('\nTotal time (s) = %.2f' % (time.time() - tstart))
@@ -807,7 +809,8 @@ def ADMML2(proxg, Op, b, A, x0, tau, niter=10, callback=None, show=False, **kwar
807809
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
808810
pf, pg = 0.5 * np.linalg.norm(Op @ x - b) ** 2, proxg(Ax)
809811
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
810-
(iiter + 1, x[0], pf, pg, pf + pg)
812+
(iiter + 1, np.real(to_numpy(x[0])),
813+
pf, pg, pf + pg)
811814
print(msg)
812815
if show:
813816
print('\nTotal time (s) = %.2f' % (time.time() - tstart))
@@ -912,7 +915,8 @@ def LinearizedADMM(proxf, proxg, A, x0, tau, mu, niter=10,
912915
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
913916
pf, pg = proxf(x), proxg(Ax)
914917
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
915-
(iiter + 1, x[0], pf, pg, pf + pg)
918+
(iiter + 1, np.real(to_numpy(x[0])),
919+
pf, pg, pf + pg)
916920
print(msg)
917921
if show:
918922
print('\nTotal time (s) = %.2f' % (time.time() - tstart))
@@ -1060,7 +1064,8 @@ def TwIST(proxg, A, b, x0, alpha=None, beta=None, eigs=None, niter=10,
10601064
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
10611065
pf, pg = proxf(x), proxg(x)
10621066
msg = '%6g %12.5e %10.3e %10.3e %10.3e' % \
1063-
(iiter + 1, np.real(to_numpy(x[0])), pf, pg, pf + pg)
1067+
(iiter + 1, np.real(to_numpy(x[0])),
1068+
pf, pg, pf + pg)
10641069
print(msg)
10651070
if show:
10661071
print('\nTotal time (s) = %.2f' % (time.time() - tstart))

0 commit comments

Comments
 (0)