Skip to content

Commit 862f752

Browse files
authored
Merge pull request #89 from mrava87/patch-factorize
feature: added factorize option to L2
2 parents 6088ff5 + e474bdf commit 862f752

File tree

2 files changed

+43
-25
lines changed

2 files changed

+43
-25
lines changed

pyproximal/proximal/L2.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
2+
from scipy.linalg import cho_factor, cho_solve
33
from scipy.sparse.linalg import lsqr
44
from pylops import MatrixMult, Identity
55
from pyproximal.ProxOperator import _check_tau
@@ -39,9 +39,12 @@ class L2(ProxOperator):
3939
Warm start (``True``) or not (``False``). Uses estimate from previous
4040
call of ``prox`` method.
4141
densesolver : :obj:`str`, optional
42-
Use ``numpy`` or ``scipy`` solver when dealing with explicit operators.
43-
Choose ``densesolver=None`` when using PyLops versions earlier than
44-
v1.18.1 or v2.0.0
42+
Use ``numpy``, ``scipy``, or ``factorize`` when dealing with explicit
43+
operators. The former two rely on dense solvers from either library,
44+
whilst the last computes a factorization of the matrix to invert and
45+
avoids to do so unless the :math:`\tau` or :math:`\sigma` paramets
46+
have changed. Choose ``densesolver=None`` when using PyLops versions
47+
earlier than v1.18.1 or v2.0.0
4548
4649
Notes
4750
-----
@@ -96,6 +99,11 @@ def __init__(self, Op=None, b=None, q=None, sigma=1., alpha=1.,
9699
self.densesolver = densesolver
97100
self.count = 0
98101

102+
# when using factorize, store the first tau*sigma=0 so that the
103+
# first time it will be recomputed (as tau cannot be 0)
104+
if self.densesolver == 'factorize':
105+
self.tausigma = 0
106+
99107
# create data term
100108
if self.Op is not None and self.b is not None:
101109
self.OpTb = self.sigma * self.Op.H @ self.b
@@ -137,14 +145,23 @@ def prox(self, x, tau):
137145
if self.q is not None:
138146
y -= tau * self.alpha * self.q
139147
if self.Op.explicit:
140-
Op1 = MatrixMult(np.eye(self.Op.shape[1]) +
141-
tau * self.sigma * self.ATA)
142-
if self.densesolver is None:
143-
# to allow backward compatibility with PyLops versions earlier
144-
# than v1.18.1 and v2.0.0
145-
x = Op1.div(y)
148+
if self.densesolver != 'factorize':
149+
Op1 = MatrixMult(np.eye(self.Op.shape[1]) +
150+
tau * self.sigma * self.ATA)
151+
if self.densesolver is None:
152+
# to allow backward compatibility with PyLops versions earlier
153+
# than v1.18.1 and v2.0.0
154+
x = Op1.div(y)
155+
else:
156+
x = Op1.div(y, densesolver=self.densesolver)
146157
else:
147-
x = Op1.div(y, densesolver=self.densesolver)
158+
if self.tausigma != tau * self.sigma:
159+
# recompute factorization
160+
self.tausigma = tau * self.sigma
161+
ATA = np.eye(self.Op.shape[1]) + \
162+
self.tausigma * self.ATA
163+
self.cl = cho_factor(ATA)
164+
x = cho_solve(self.cl, y)
148165
else:
149166
Op1 = Identity(self.Op.shape[1], dtype=self.Op.dtype) + \
150167
tau * self.sigma * self.Op.H * self.Op

pytests/test_norms.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -89,20 +89,21 @@ def test_L2_op(par):
8989
def test_L2_dense(par):
9090
"""L2 norm of Op*x with dense Op and proximal/dual proximal
9191
"""
92-
b = np.zeros(par['nx'], dtype=par['dtype'])
93-
d = np.random.normal(0., 1., par['nx']).astype(par['dtype'])
94-
l2 = L2(Op=MatrixMult(np.diag(d), dtype=par['dtype']),
95-
b=b, sigma=par['sigma'], densesolver='numpy')
96-
97-
# norm
98-
x = np.random.normal(0., 1., par['nx']).astype(par['dtype'])
99-
assert l2(x) == (par['sigma'] / 2.) * np.linalg.norm(d * x) ** 2
100-
101-
# prox: since Op is a Diagonal operator the denominator becomes
102-
# 1 + sigma*tau*d[i] for every i
103-
tau = 2.
104-
den = 1. + par['sigma'] * tau * d ** 2
105-
assert_array_almost_equal(l2.prox(x, tau), x / den, decimal=4)
92+
for densesolver in ('numpy', 'scipy', 'factorize'):
93+
b = np.zeros(par['nx'], dtype=par['dtype'])
94+
d = np.random.normal(0., 1., par['nx']).astype(par['dtype'])
95+
l2 = L2(Op=MatrixMult(np.diag(d), dtype=par['dtype']),
96+
b=b, sigma=par['sigma'], densesolver=densesolver)
97+
98+
# norm
99+
x = np.random.normal(0., 1., par['nx']).astype(par['dtype'])
100+
assert l2(x) == (par['sigma'] / 2.) * np.linalg.norm(d * x) ** 2
101+
102+
# prox: since Op is a Diagonal operator the denominator becomes
103+
# 1 + sigma*tau*d[i] for every i
104+
tau = 2.
105+
den = 1. + par['sigma'] * tau * d ** 2
106+
assert_array_almost_equal(l2.prox(x, tau), x / den, decimal=4)
106107

107108

108109
@pytest.mark.parametrize("par", [(par1), (par2)])

0 commit comments

Comments
 (0)