Skip to content

Commit 2078a6b

Browse files
committed
Modified L2 for cupy
1 parent 7a88310 commit 2078a6b

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

pyproximal/projection/AffineSet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from scipy.sparse.linalg import lsqr as sp_lsqr
2-
32
from pylops.optimization.basic import lsqr
43
from pylops.utils.backend import get_array_module, get_module_name
54

5+
66
class AffineSetProj():
77
r"""Affine set projection.
88

pyproximal/proximal/L2.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import numpy as np
22
from scipy.linalg import cho_factor, cho_solve
3-
from scipy.sparse.linalg import lsqr
3+
from scipy.sparse.linalg import lsqr as sp_lsqr
44
from pylops import MatrixMult, Identity
5+
from pylops.optimization.basic import lsqr
6+
from pylops.utils.backend import get_array_module, get_module_name
7+
58
from pyproximal.ProxOperator import _check_tau
69
from pyproximal import ProxOperator
710

@@ -164,8 +167,11 @@ def prox(self, x, tau):
164167
x = cho_solve(self.cl, y)
165168
else:
166169
Op1 = Identity(self.Op.shape[1], dtype=self.Op.dtype) + \
167-
tau * self.sigma * self.Op.H * self.Op
168-
x = lsqr(Op1, y, iter_lim=niter, x0=self.x0)[0]
170+
float(tau * self.sigma) * (self.Op.H * self.Op)
171+
if get_module_name(get_array_module(x)) == 'numpy':
172+
x = sp_lsqr(Op1, y, iter_lim=niter, x0=self.x0)[0]
173+
else:
174+
x = lsqr(Op1, y, niter=niter, x0=self.x0)[0].ravel()
169175
if self.warm:
170176
self.x0 = x
171177
elif self.b is not None:

0 commit comments

Comments
 (0)