Skip to content

Commit 56ab7cb

Browse files
committed
feature: Allow passing optional arguments to solvers in L2
1 parent 2078a6b commit 56ab7cb

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

pyproximal/optimization/primaldual.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ def PrimalDual(proxf, proxg, A, x0, tau, mu, y0=None, z=None, theta=1., niter=10
113113
# check if tau and mu are scalars or arrays
114114
fixedtau = fixedmu = False
115115
if isinstance(tau, (int, float)):
116-
tau = tau * ncp.ones(niter)
116+
tau = tau * ncp.ones(niter, dtype=x0.dtype)
117117
fixedtau = True
118118
if isinstance(mu, (int, float)):
119-
mu = mu * ncp.ones(niter)
119+
mu = mu * ncp.ones(niter, dtype=x0.dtype)
120120
fixedmu = True
121121

122122
if show:

pyproximal/proximal/L2.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ class L2(ProxOperator):
4848
avoids to do so unless the :math:`\tau` or :math:`\sigma` paramets
4949
have changed. Choose ``densesolver=None`` when using PyLops versions
5050
earlier than v1.18.1 or v2.0.0
51+
**kwargs_solver : :obj:`dict`, optional
52+
Dictionary containing extra arguments for
53+
:py:func:`scipy.sparse.linalg.lsqr` solver when using
54+
numpy data (or :py:func:`pylops.optimization.solver.lsqr` and
55+
when using cupy data)
5156
5257
Notes
5358
-----
@@ -89,7 +94,8 @@ class L2(ProxOperator):
8994
9095
"""
9196
def __init__(self, Op=None, b=None, q=None, sigma=1., alpha=1.,
92-
qgrad=True, niter=10, x0=None, warm=True, densesolver=None):
97+
qgrad=True, niter=10, x0=None, warm=True,
98+
densesolver=None, kwargs_solver=None):
9399
super().__init__(Op, True)
94100
self.b = b
95101
self.q = q
@@ -101,6 +107,7 @@ def __init__(self, Op=None, b=None, q=None, sigma=1., alpha=1.,
101107
self.warm = warm
102108
self.densesolver = densesolver
103109
self.count = 0
110+
self.kwargs_solver = {} if kwargs_solver is None else kwargs_solver
104111

105112
# when using factorize, store the first tau*sigma=0 so that the
106113
# first time it will be recomputed (as tau cannot be 0)
@@ -169,9 +176,11 @@ def prox(self, x, tau):
169176
Op1 = Identity(self.Op.shape[1], dtype=self.Op.dtype) + \
170177
float(tau * self.sigma) * (self.Op.H * self.Op)
171178
if get_module_name(get_array_module(x)) == 'numpy':
172-
x = sp_lsqr(Op1, y, iter_lim=niter, x0=self.x0)[0]
179+
x = sp_lsqr(Op1, y, iter_lim=niter, x0=self.x0,
180+
**self.kwargs_solver)[0]
173181
else:
174-
x = lsqr(Op1, y, niter=niter, x0=self.x0)[0].ravel()
182+
x = lsqr(Op1, y, niter=niter, x0=self.x0,
183+
**self.kwargs_solver)[0].ravel()
175184
if self.warm:
176185
self.x0 = x
177186
elif self.b is not None:

0 commit comments

Comments
 (0)