Skip to content

Commit e141baa

Browse files
committed
WIP: batch sparse.linalg.cg
[skip ci]
1 parent fe03769 commit e141baa

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

scipy/sparse/linalg/_isolve/iterative.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from .utils import make_system
77
from scipy.linalg import get_lapack_funcs
88

9+
from scipy._lib import array_api_extra as xpx
10+
911
__all__ = ['bicg', 'bicgstab', 'cg', 'cgs', 'gmres', 'qmr']
1012

1113

@@ -399,21 +401,38 @@ def cg(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M=None, callback=None
399401
rho_prev, p = None, None
400402

401403
for iteration in range(maxiter):
402-
if np.all(np.linalg.norm(r, axis=-1) < atol): # Are we done?
404+
converged = np.linalg.norm(r, axis=-1) < atol
405+
if np.all(converged):
403406
return x, 0
404407

405408
z = psolve(r)
406409
rho_cur = dotprod(r, z)
410+
407411
if iteration > 0:
408-
beta = rho_cur / rho_prev
409-
p = (beta * p.T).T
412+
beta = xpx.apply_where(
413+
not converged,
414+
(rho_cur, rho_prev),
415+
lambda cur, prev: cur / prev,
416+
fill_value=0.0,
417+
xp=np
418+
)
419+
p = (beta * p.T).T
410420
p += z
411421
else: # First spin
412422
p = np.empty_like(r)
413423
p[:] = z[:]
414424

415425
q = matvec(p)
416-
alpha = rho_cur / dotprod(p, q)
426+
c = dotprod(p, q)
427+
428+
alpha = xpx.apply_where(
429+
not converged,
430+
(rho_cur, c),
431+
lambda rc, c: rc / c,
432+
fill_value=0.0,
433+
xp=np
434+
)
435+
417436
x += (alpha * p.T).T
418437
r -= (alpha * q.T).T
419438
rho_prev = rho_cur

0 commit comments

Comments
 (0)