Skip to content

Commit 3898d37

Browse files
committed
WIP: batch sparse.linalg.cg
[skip ci]
1 parent fe03769 commit 3898d37

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

scipy/sparse/linalg/_isolve/iterative.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from .utils import make_system
77
from scipy.linalg import get_lapack_funcs
88

9+
from scipy._lib import array_api_extra as xpx
10+
911
__all__ = ['bicg', 'bicgstab', 'cg', 'cgs', 'gmres', 'qmr']
1012

1113

@@ -399,21 +401,43 @@ def cg(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M=None, callback=None
399401
rho_prev, p = None, None
400402

401403
for iteration in range(maxiter):
402-
if np.all(np.linalg.norm(r, axis=-1) < atol): # Are we done?
404+
converged = np.linalg.norm(r, axis=-1) < atol
405+
if np.all(converged):
403406
return x, 0
404407

405408
z = psolve(r)
406409
rho_cur = dotprod(r, z)
410+
407411
if iteration > 0:
412+
<<<<<<< HEAD
408413
beta = rho_cur / rho_prev
409414
p = (beta * p.T).T
415+
=======
416+
beta = xpx.apply_where(
417+
not converged,
418+
(rho_cur, rho_prev),
419+
lambda cur, prev: cur / prev,
420+
fill_value=0.0,
421+
xp=np
422+
)
423+
p = (beta * p.T).T
424+
>>>>>>> 07ade92418 (WIP: batch `sparse.linalg.cg`)
410425
p += z
411426
else: # First spin
412427
p = np.empty_like(r)
413428
p[:] = z[:]
414429

415430
q = matvec(p)
416-
alpha = rho_cur / dotprod(p, q)
431+
c = dotprod(p, q)
432+
433+
alpha = xpx.apply_where(
434+
not converged,
435+
(rho_cur, c),
436+
lambda rc, c: rc / c,
437+
fill_value=0.0,
438+
xp=np
439+
)
440+
417441
x += (alpha * p.T).T
418442
r -= (alpha * q.T).T
419443
rho_prev = rho_cur

0 commit comments

Comments
 (0)